<a href="https://colab.research.google.com/github/satani99/attention_is_all_you_need/blob/main/attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
# Toy input: a sentence already encoded as eight integer token ids.
token_ids = [0, 7, 1, 2, 5, 6, 4, 3]
sentence = torch.tensor(token_ids)
sentence

tensor([0, 7, 1, 2, 5, 6, 4, 3])

In [2]:
# Seed first so the randomly initialised embedding table is reproducible.
torch.manual_seed(42)
# Lookup table with 10 rows, each a 16-dimensional vector.
embed = torch.nn.Embedding(num_embeddings=10, embedding_dim=16)
# Detach: only the embedding values are used below, not gradient tracking.
embedded_sentence = embed(sentence).detach()
embedded_sentence.size()

torch.Size([8, 16])

In [3]:
# Unnormalised self-attention scores: omega[i, j] = x_i . x_j for every pair
# of token embeddings. The matrix size is taken from the data rather than
# hard-coded, so this cell keeps working if the sentence length changes.
seq_len = embedded_sentence.shape[0]
omega = torch.empty(seq_len, seq_len)
for i, x_i in enumerate(embedded_sentence):
  for j, x_j in enumerate(embedded_sentence):
    omega[i, j] = torch.dot(x_i, x_j)

# The same scores computed with a single matrix product, X @ X^T.
omega_mat = embedded_sentence.matmul(embedded_sentence.T)

In [4]:
# Sanity check: the loop-built scores and the matrix-product scores agree.
torch.allclose(omega_mat, omega)

True

In [5]:
# One row and one column per token in the sentence.
omega.shape

torch.Size([8, 8])

In [6]:
import torch.nn.functional as F 
# Normalise each row of scores with softmax so that every token's attention
# weights over all tokens sum to 1.
attention_weights = torch.softmax(omega, dim=1)
attention_weights.size()

torch.Size([8, 8])

In [7]:
# Inspect the full weight matrix (row i holds the weights for token i).
attention_weights

tensor([[1.0000e+00, 8.8716e-14, 4.5282e-12, 1.9794e-11, 1.7691e-11, 3.4558e-12,
         1.1407e-08, 3.7097e-12],
        [1.5851e-06, 8.9771e-01, 1.0376e-03, 2.9957e-03, 1.2276e-03, 8.4663e-02,
         7.9979e-05, 1.2282e-02],
        [2.1346e-08, 2.7376e-07, 9.9999e-01, 3.4414e-09, 5.4635e-06, 1.4109e-08,
         8.3541e-06, 1.9875e-07],
        [1.1455e-09, 9.7027e-09, 4.2247e-11, 9.9995e-01, 3.4450e-10, 3.9137e-06,
         7.0079e-10, 4.4827e-05],
        [1.1860e-06, 4.6064e-06, 7.7701e-05, 3.9909e-07, 9.9992e-01, 8.8794e-07,
         1.1313e-09, 1.5174e-07],
        [2.8595e-07, 3.9208e-04, 2.4766e-07, 5.5960e-03, 1.0959e-06, 9.9353e-01,
         3.7887e-07, 4.7552e-04],
        [2.9961e-06, 1.1757e-09, 4.6546e-07, 3.1805e-09, 4.4321e-12, 1.2026e-09,
         1.0000e+00, 8.1439e-09],
        [5.3149e-08, 9.8482e-06, 6.0406e-07, 1.1098e-02, 3.2428e-08, 8.2335e-05,
         4.4424e-07, 9.8881e-01]])

In [8]:
# Each row was softmax-normalised, so every row sum should be 1.
torch.sum(attention_weights, dim=1)

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])

In [9]:
# Embedding of the second token (row index 1).
x_2 = embedded_sentence[1]

In [10]:
# Zero accumulator with the same shape as one token embedding.
context_vec_2 = torch.zeros_like(x_2)

In [11]:
# Context vector for the second token (row index 1): the attention-weighted
# sum of all token embeddings. The accumulator is reset here so that
# re-running this cell does not add onto a previous result, and the loop
# follows the actual number of tokens instead of a hard-coded 8.
context_vec_2 = torch.zeros(x_2.shape)
for j, x_j in enumerate(embedded_sentence):
  context_vec_2 += attention_weights[1, j] * x_j

In [12]:
# Display the loop-computed context vector for the second token.
context_vec_2

tensor([-1.5340, -0.0618, -0.4731,  0.4421,  0.7010,  0.1463, -0.2659,  0.4264,
        -0.0467,  0.2842,  0.1492,  0.7098,  0.9981,  0.4409,  0.6292,  0.3606])

In [13]:
# All context vectors at once: row i is the attention-weighted sum of the
# embeddings for token i.
context_vector = attention_weights @ embedded_sentence

In [14]:
# One context vector per token, each with the embedding dimensionality.
context_vector.shape

torch.Size([8, 16])

In [15]:
# Sanity check: the loop result for the second token matches row 1 of the
# matrix-product result.
torch.allclose(context_vec_2, context_vector[1])

True

## Scaled dot-product attention

In [21]:
# Re-seed so the three projection matrices are reproducible.
torch.manual_seed(42)
d = embedded_sentence.size(1)
# Square (d x d) query, key and value projections, drawn in that order so the
# random stream is consumed exactly as three consecutive torch.rand calls.
U_query, U_key, U_value = (torch.rand(d, d) for _ in range(3))

In [22]:
# Query vector for the second token (row index 1): U_query applied to its embedding.
x_2 = embedded_sentence[1]
query_2 = torch.matmul(U_query, x_2)

In [23]:
# Key and value vectors for the same token, using their own projections.
key_2 = torch.matmul(U_key, x_2)
value_2 = torch.matmul(U_value, x_2)

In [30]:
# Display the query vector of the second token.
query_2

tensor([ 1.9594,  0.9311,  1.9671,  0.8957,  0.0800,  1.5594, -0.1917,  0.2971,
         0.3863,  1.8167,  0.8888,  1.2970,  0.6111,  0.9745,  1.6844,  0.5390])

In [31]:
# Project every token embedding in one go. U @ X^T puts tokens in columns;
# the outer transpose puts them back in rows (one key / value per row).
keys = (U_key @ embedded_sentence.T).T
values = (U_value @ embedded_sentence.T).T

In [32]:
# One key vector per token.
keys.shape

torch.Size([8, 16])

In [33]:
# Unnormalised score between the second token's query and the third token's
# key (row index 2).
omega_23 = torch.dot(query_2, keys[2])
omega_23

tensor(70.6820)

In [34]:
# Scores of the second token's query against every key at once.
omega_2 = query_2 @ keys.T
omega_2

tensor([ -2.9915,  19.0139,  70.6820, -32.9331,  -9.6415,  20.1790,  23.0588,
        -12.4027])

In [38]:
# Scale the scores by sqrt(d) before the softmax, then normalise over the
# single (token) axis of this 1-D score vector.
attention_weights_2 = torch.softmax(omega_2 / d**0.5, dim=0)
attention_weights_2

tensor([1.0023e-08, 2.4558e-06, 9.9999e-01, 5.6250e-12, 1.9010e-09, 3.2862e-06,
        6.7510e-06, 9.5321e-10])

In [42]:
# Context vector for the second token: attention-weighted sum of the value vectors.
context_vector_2 = attention_weights_2 @ values
context_vector_2

tensor([4.9591, 5.3071, 5.1345, 1.8415, 4.8305, 3.2982, 5.7193, 5.1012, 4.5425,
        5.6862, 5.7829, 6.0999, 6.0630, 3.1988, 4.1003, 3.3661])