<a href="https://colab.research.google.com/github/satani99/attention_is_all_you_need/blob/main/self_attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
sentence = "Life is short, eat dessert first"

dc = {s:i for i, s in enumerate(sorted(sentence.replace(',', '').split()))}
print(dc)

{'Life': 0, 'dessert': 1, 'eat': 2, 'first': 3, 'is': 4, 'short': 5}


In [2]:
import torch 

sentence_int = torch.tensor([dc[s] for s in sentence.replace(',', '').split()])
print(sentence_int)

tensor([0, 4, 5, 2, 1, 3])


In [3]:
torch.manual_seed(123)
embed = torch.nn.Embedding(6, 16)
embedded_sentence = embed(sentence_int).detach()

print(embedded_sentence)
print(embedded_sentence.shape)

tensor([[ 0.3374, -0.1778, -0.3035, -0.5880,  0.3486,  0.6603, -0.2196, -0.3792,
          0.7671, -1.1925,  0.6984, -1.4097,  0.1794,  1.8951,  0.4954,  0.2692],
        [ 0.5146,  0.9938, -0.2587, -1.0826, -0.0444,  1.6236, -2.3229,  1.0878,
          0.6716,  0.6933, -0.9487, -0.0765, -0.1526,  0.1167,  0.4403, -1.4465],
        [ 0.2553, -0.5496,  1.0042,  0.8272, -0.3948,  0.4892, -0.2168, -1.7472,
         -1.6025, -1.0764,  0.9031, -0.7218, -0.5951, -0.7112,  0.6230, -1.3729],
        [-1.3250,  0.1784, -2.1338,  1.0524, -0.3885, -0.9343, -0.4991, -1.0867,
          0.8805,  1.5542,  0.6266, -0.1755,  0.0983, -0.0935,  0.2662, -0.5850],
        [-0.0770, -1.0205, -0.1690,  0.9178,  1.5810,  1.3010,  1.2753, -0.2010,
          0.4965, -1.5723,  0.9666, -1.1481, -1.1589,  0.3255, -0.6315, -2.8400],
        [ 0.8768,  1.6221, -1.4779,  1.1331, -1.2203,  1.3139,  1.0533,  0.1388,
          2.2473, -0.8036, -0.2808,  0.7697, -0.6596, -0.7979,  0.1838,  0.2293]])
torch.Size([6, 16])


In [4]:
torch.manual_seed(123)

d = embedded_sentence.shape[1]

d_q, d_k, d_v = 24, 24, 28 

W_query = torch.rand(d_q, d)
W_key = torch.rand(d_k, d)
W_value = torch.rand(d_v, d)

In [5]:
x_2 = embedded_sentence[1]
query_2 = W_query.matmul(x_2)
key_2 = W_key.matmul(x_2)
value_2 = W_value.matmul(x_2)

print(query_2.shape)
print(key_2.shape)
print(value_2.shape)

torch.Size([24])
torch.Size([24])
torch.Size([28])


In [6]:
keys = W_key.matmul(embedded_sentence.T).T
values = W_value.matmul(embedded_sentence.T).T

print("keys.shape", keys.shape)
print("values.shape", values.shape)

keys.shape torch.Size([6, 24])
values.shape torch.Size([6, 28])


In [7]:
# compute the unnormalized attention weight for the query 
# and the 5th input 
omega_24 = query_2.dot(keys[4])
print(omega_24)

tensor(11.1466)


In [8]:
omega_2 = query_2.matmul(keys.T)
print(omega_2)

tensor([ 8.5808, -7.6597,  3.2558,  1.0395, 11.1466, -0.4800])


In [9]:
# Computing the attention weights 
import torch.nn.functional as F 

attention_weights_2 = F.softmax(omega_2 / d_k**0.5, dim=0)
print(attention_weights_2)

tensor([0.2912, 0.0106, 0.0982, 0.0625, 0.4917, 0.0458])


In [10]:
context_vector_2 = attention_weights_2.matmul(values)

print(context_vector_2.shape)
print(context_vector_2)

torch.Size([28])
tensor([-1.5993,  0.0156,  1.2670,  0.0032, -0.6460, -1.1407, -0.4908, -1.4632,
         0.4747,  1.1926,  0.4506, -0.7110,  0.0602,  0.7125, -0.1628, -2.0184,
         0.3838, -2.1188, -0.8136, -1.5694,  0.7934, -0.2911, -1.3640, -0.2366,
        -0.9564, -0.5265,  0.0624,  1.7084])


In [11]:
# Multi-Head Attention with example of 3 attention heads

h = 3 
multihead_W_query = torch.rand(h, d_q, d)
multihead_W_key = torch.rand(h, d_k, d)
multihead_W_value = torch.rand(h, d_v, d)
print(multihead_W_query.shape)

torch.Size([3, 24, 16])


In [12]:
multihead_query_2 = multihead_W_query.matmul(x_2)
print(multihead_query_2.shape)

torch.Size([3, 24])


In [13]:
multihead_key_2 = multihead_W_key.matmul(x_2)
multihead_value_2 = multihead_W_value.matmul(x_2)

In [14]:
stacked_inputs = embedded_sentence.T.repeat(3, 1, 1)
print(stacked_inputs.shape)

torch.Size([3, 16, 6])


In [27]:
multihead_keys = torch.bmm(multihead_W_key, stacked_inputs)
multihead_values = torch.bmm(multihead_W_value, stacked_inputs)
print("multihead_keys.shape:", multihead_keys.shape)
print("multihead_values.shape:", multihead_values.shape)

multihead_keys.shape: torch.Size([3, 24, 6])
multihead_values.shape: torch.Size([3, 28, 6])


In [28]:
# multihead_keys = multihead_keys.permute(0, 2, 1)
multihead_values = multihead_values.permute(0, 2, 1)
print("multihead_keys.shape:", multihead_keys.shape)
print("multihead_values.shape:", multihead_values.shape)

multihead_keys.shape: torch.Size([3, 24, 6])
multihead_values.shape: torch.Size([3, 6, 28])


In [29]:
omega_multihead_2 = torch.bmm(multihead_query_2.view(3, 1, 24), multihead_keys)
print(omega_multihead_2)
print(omega_multihead_2.shape)

tensor([[[ -0.9446,  -9.9220,  24.6713,   9.3329,  18.5821, -18.7331]],

        [[  2.6489,  -8.1238,  -6.9092,  -9.2760,  -5.8502,   1.2093]],

        [[  0.3535,   6.1598,  -1.7309,  -0.6793,  -1.9436, -13.1709]]])
torch.Size([3, 1, 6])


In [30]:
# Computing the attention weights 
import torch.nn.functional as F 

multihead_attention_weights_2 = F.softmax(omega_multihead_2 / d_k**0.5, dim=0)
print(multihead_attention_weights_2)

tensor([[[0.2280, 0.0344, 0.9939, 0.8681, 0.9785, 0.0159]],

        [[0.4748, 0.0496, 0.0016, 0.0194, 0.0067, 0.9344]],

        [[0.2972, 0.9160, 0.0045, 0.1125, 0.0148, 0.0496]]])


In [33]:
multihead_context_vector_2 = multihead_attention_weights_2.matmul(multihead_values).view(3, 28)

print(multihead_context_vector_2.shape)
print(multihead_context_vector_2)

torch.Size([3, 28])
tensor([[-4.0196, -2.8839, -2.7928, -4.8711, -5.3576, -5.6695, -7.1135,  0.4474,
         -6.0364, -2.5727, -6.9008, -6.2505, -2.4347, -6.2277, -5.0631, -6.3769,
         -5.6076, -0.2504, -7.5703, -2.8492, -2.0196, -3.0125, -3.3246, -2.6591,
         -3.9002, -3.3792, -7.1753, -1.9207],
        [ 1.8541,  1.3980,  1.3991,  1.9406,  4.3007,  3.2203,  2.7824,  2.3874,
          1.9616,  3.5196,  1.0982,  2.4626,  2.7614,  0.8944,  4.2907,  2.2192,
          1.4597,  3.4831,  3.7664,  4.9082,  1.3324,  3.0179,  1.8540,  3.0721,
          4.9994,  2.8471,  3.5666,  1.9237],
        [-1.3329,  1.2958,  1.3449, -1.0297,  0.7839, -0.4812, -0.6817, -1.3890,
         -1.6334,  1.5294, -1.2365, -0.2269, -0.7305, -0.6153, -0.3127, -1.6055,
         -0.1556,  0.2809,  1.8245, -0.2797, -0.6460,  1.2682,  0.9108,  1.2385,
         -0.3255,  0.7398,  0.0181, -0.0865]])
