<a href="https://colab.research.google.com/github/satani99/attention_is_all_you_need/blob/main/attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
sentence = torch.tensor(
    [0,
     7,
     1,
     2,
     5,
     6,
     4,
     3]
)
sentence

tensor([0, 7, 1, 2, 5, 6, 4, 3])

In [2]:
torch.manual_seed(42)
embed = torch.nn.Embedding(10, 16)
embedded_sentence = embed(sentence).detach()
embedded_sentence.shape

torch.Size([8, 16])

In [3]:
omega = torch.empty(8, 8)
for i, x_i in enumerate(embedded_sentence):
  for j, x_j in enumerate(embedded_sentence):
    omega[i, j] = torch.dot(x_i, x_j)

omega_mat = embedded_sentence.matmul(embedded_sentence.T)

In [4]:
torch.allclose(omega_mat, omega)

True

In [5]:
omega.shape

torch.Size([8, 8])

In [6]:
import torch.nn.functional as F 
attention_weights = F.softmax(omega, dim=1)
attention_weights.shape

torch.Size([8, 8])

In [7]:
attention_weights

tensor([[1.0000e+00, 8.8716e-14, 4.5282e-12, 1.9794e-11, 1.7690e-11, 3.4558e-12,
         1.1407e-08, 3.7097e-12],
        [1.5851e-06, 8.9771e-01, 1.0376e-03, 2.9957e-03, 1.2276e-03, 8.4663e-02,
         7.9979e-05, 1.2282e-02],
        [2.1346e-08, 2.7376e-07, 9.9999e-01, 3.4414e-09, 5.4635e-06, 1.4109e-08,
         8.3542e-06, 1.9875e-07],
        [1.1455e-09, 9.7027e-09, 4.2247e-11, 9.9995e-01, 3.4450e-10, 3.9137e-06,
         7.0079e-10, 4.4827e-05],
        [1.1860e-06, 4.6064e-06, 7.7701e-05, 3.9909e-07, 9.9992e-01, 8.8794e-07,
         1.1313e-09, 1.5174e-07],
        [2.8595e-07, 3.9208e-04, 2.4766e-07, 5.5960e-03, 1.0959e-06, 9.9353e-01,
         3.7887e-07, 4.7552e-04],
        [2.9961e-06, 1.1757e-09, 4.6546e-07, 3.1805e-09, 4.4321e-12, 1.2026e-09,
         1.0000e+00, 8.1439e-09],
        [5.3149e-08, 9.8482e-06, 6.0405e-07, 1.1098e-02, 3.2428e-08, 8.2334e-05,
         4.4424e-07, 9.8881e-01]])

In [8]:
attention_weights.sum(dim=1)

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])

In [9]:
x_2 = embedded_sentence[1, :]

In [10]:
context_vec_2 = torch.zeros(x_2.shape)

In [11]:
for j in range(8):
  x_j = embedded_sentence[j, :]
  context_vec_2 += attention_weights[1, j] * x_j

In [12]:
context_vec_2

tensor([-1.5340, -0.0618, -0.4731,  0.4421,  0.7010,  0.1463, -0.2659,  0.4264,
        -0.0467,  0.2842,  0.1492,  0.7098,  0.9981,  0.4409,  0.6292,  0.3606])

In [13]:
context_vector = torch.matmul(attention_weights, embedded_sentence)

In [14]:
context_vector.shape

torch.Size([8, 16])

In [15]:
torch.allclose(context_vec_2, context_vector[1])

True

##Scaled Dot-product attention

In [16]:
torch.manual_seed(42)
d = embedded_sentence.shape[1]
U_query = torch.rand(d, d)
U_key = torch.rand(d, d)
U_value = torch.rand(d, d)

In [17]:
x_2 = embedded_sentence[1]
query_2 = U_query.matmul(x_2)

In [18]:
key_2 = U_key.matmul(x_2)
value_2 = U_value.matmul(x_2)

In [19]:
query_2

tensor([ 1.9594,  0.9311,  1.9671,  0.8957,  0.0800,  1.5594, -0.1917,  0.2971,
         0.3863,  1.8167,  0.8888,  1.2970,  0.6111,  0.9745,  1.6844,  0.5390])

In [20]:
keys = U_key.matmul(embedded_sentence.T).T
values = U_value.matmul(embedded_sentence.T).T

In [21]:
keys.shape

torch.Size([8, 16])

In [22]:
omega_23 = query_2.dot(keys[2])
omega_23

tensor(70.6821)

In [23]:
omega_2 = query_2.matmul(keys.T)
omega_2

tensor([ -2.9915,  19.0139,  70.6821, -32.9331,  -9.6415,  20.1790,  23.0588,
        -12.4027])

In [24]:
attention_weights_2 = F.softmax(omega_2 / d**0.5, dim=0)
attention_weights_2

tensor([1.0023e-08, 2.4558e-06, 9.9999e-01, 5.6249e-12, 1.9009e-09, 3.2862e-06,
        6.7510e-06, 9.5320e-10])

In [25]:
context_vector_2 = attention_weights_2.matmul(values)
context_vector_2.shape

torch.Size([16])

##Encoding context embedding via multi-head attention

In [26]:
torch.manual_seed(42)
d = embedded_sentence.shape[1]
one_U_query = torch.rand(d, d)

In [27]:
h = 8
multihead_U_query = torch.rand(h, d, d)
multihead_U_key = torch.rand(h, d, d)
multihead_U_value = torch.rand(h, d, d)

In [28]:
multihead_U_query.shape

torch.Size([8, 16, 16])

In [29]:
multihead_query_2 = multihead_U_query.matmul(x_2)
multihead_query_2.shape

torch.Size([8, 16])

In [30]:
multihead_key_2 = multihead_U_key.matmul(x_2)
multihead_value_2 = multihead_U_value.matmul(x_2)
multihead_key_2[2] #the key vector of the second input element via the third attention head

tensor([2.0321, 3.2833, 1.5345, 2.3700, 1.1367, 1.2881, 1.6517, 0.1471, 0.7686,
        0.6164, 1.2495, 0.9765, 1.7931, 0.3502, 2.4158, 2.2162])

In [31]:
stacked_inputs = embedded_sentence.T.repeat(8, 1, 1)
stacked_inputs.shape

torch.Size([8, 16, 8])

In [32]:
multihead_keys = torch.bmm(multihead_U_key, stacked_inputs)
multihead_keys.shape

torch.Size([8, 16, 8])

In [33]:
multihead_keys = multihead_keys.permute(0, 2, 1)
multihead_keys.shape

torch.Size([8, 8, 16])

In [34]:
multihead_keys[2, 1]

tensor([2.0321, 3.2833, 1.5345, 2.3700, 1.1367, 1.2881, 1.6517, 0.1471, 0.7686,
        0.6164, 1.2495, 0.9765, 1.7931, 0.3502, 2.4158, 2.2162])

In [35]:
multihead_values = torch.matmul(multihead_U_value, stacked_inputs)
multihead_values = multihead_values.permute(0, 2, 1)
multihead_values.shape

torch.Size([8, 8, 16])

In [36]:
multihead_z_2 = torch.rand(8, 16)

In [43]:
omega_multi_2 = torch.bmm(multihead_query_2.view(8, 1, 16), multihead_keys.permute(0, 2, 1))

In [40]:
omega_multi_2 = omega_multi_2.reshape(8, 8)

In [44]:
omega_multi_2.shape

torch.Size([8, 1, 8])

In [46]:
multi_attention_weights_2 = F.softmax(omega_multi_2 / d**0.5, dim=0)
multi_attention_weights_2

tensor([[[4.2392e-01, 1.2445e-01, 3.8245e-05, 1.9425e-01, 8.1041e-04,
          1.7275e-01, 1.9889e-02, 3.1984e-02]],

        [[3.7139e-01, 1.4377e-01, 3.5716e-04, 4.2516e-01, 2.5854e-03,
          4.5327e-01, 1.9640e-02, 3.2628e-02]],

        [[3.6528e-03, 1.0176e-01, 7.2263e-01, 6.3381e-03, 9.0329e-02,
          1.3697e-02, 5.4381e-03, 1.5120e-02]],

        [[2.3187e-02, 8.6181e-02, 1.9426e-01, 2.1601e-02, 1.4555e-02,
          9.4579e-02, 7.3533e-01, 1.1744e-02]],

        [[1.3422e-03, 4.1996e-01, 2.7147e-02, 1.5609e-03, 1.6484e-01,
          1.0101e-01, 1.0680e-03, 4.4225e-03]],

        [[6.3048e-02, 2.3763e-02, 5.4945e-02, 4.6431e-04, 8.9770e-02,
          1.2336e-01, 1.4823e-01, 1.7128e-03]],

        [[9.5721e-03, 7.8261e-02, 4.8254e-04, 2.1447e-01, 4.7506e-03,
          3.2364e-02, 6.9969e-02, 8.8104e-01]],

        [[1.0389e-01, 2.1861e-02, 1.3763e-04, 1.3615e-01, 6.3236e-01,
          8.9557e-03, 4.4134e-04, 2.1353e-02]]])

In [48]:
multihead_z_2 = torch.bmm(multi_attention_weights_2, multihead_values)
multihead_z_2 = multihead_z_2.reshape(8, 16)

In [49]:
linear = torch.nn.Linear(8*16, 16)
context_vector_2 = linear(multihead_z_2.flatten())
context_vector_2.shape

torch.Size([16])

##Decoder and masked multihead attention