In [1]:
import torch

# Toy input: a sequence of eight integer token ids.
token_ids = [0, 7, 1, 2, 5, 6, 4, 3]
sentence = torch.tensor(token_ids)

sentence

tensor([0, 7, 1, 2, 5, 6, 4, 3])

In [10]:
# Seed first so the randomly initialised embedding table is reproducible.
torch.manual_seed(123)
embed = torch.nn.Embedding(num_embeddings=10, embedding_dim=16)
# Detach: the embeddings are used as plain data here, not trained.
embedded_sentence = embed(sentence).detach()
embedded_sentence.size()

torch.Size([8, 16])

In [12]:
# Pairwise dot-product scores between all token embeddings, written as
# explicit loops for illustration. The matrix size is taken from the input
# rather than hard-coded, so the cell works for any sequence length.
seq_len = embedded_sentence.shape[0]
omega = torch.empty(seq_len, seq_len)

for i, x_i in enumerate(embedded_sentence):
    for j, x_j in enumerate(embedded_sentence):
        omega[i, j] = torch.dot(x_i, x_j)

In [13]:
# Same scores in one matrix product; check against the loop version.
omega_mat = embedded_sentence @ embedded_sentence.t()
torch.allclose(omega_mat, omega)

True

In [14]:
import torch.nn.functional as F

# Normalise each row of scores into a probability distribution.
attention_weights = torch.softmax(omega, dim=1)
attention_weights.size()

torch.Size([8, 8])

In [15]:
# Sanity check: every row of softmax output should sum to 1.
torch.sum(attention_weights, dim=1)

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])

In [17]:
# Context vector for the second token: attention-weighted sum of all token
# embeddings. The loop walks the actual rows instead of assuming 8 tokens.
x_2 = embedded_sentence[1, :]
context_vec_2 = torch.zeros(x_2.shape)
for j in range(embedded_sentence.shape[0]):
    x_j = embedded_sentence[j, :]
    context_vec_2 += attention_weights[1, j] * x_j

context_vec_2

tensor([-9.3975e-01, -4.6856e-01,  1.0311e+00, -2.8192e-01,  4.9373e-01,
        -1.2896e-02, -2.7327e-01, -7.6358e-01,  1.3958e+00, -9.9543e-01,
        -7.1287e-04,  1.2449e+00, -7.8077e-02,  1.2765e+00, -1.4589e+00,
        -2.1601e+00])

In [18]:
# All context vectors at once; row 1 must match the loop result.
context_vectors = attention_weights @ embedded_sentence

torch.allclose(context_vec_2, context_vectors[1])

True

In [None]:
torch.manual_seed(123)

# Square (d x d) random projection matrices, drawn in the order
# query -> key -> value. The spelling `U_querry` is kept because later
# cells reference it by that name.
d = embedded_sentence.size(1)
U_querry, U_key, U_value = (torch.rand(d, d) for _ in range(3))

In [24]:
# Project the second token's embedding into query / key / value vectors.
x_2 = embedded_sentence[1]
query_2 = torch.matmul(x_2, U_querry)
key_2 = torch.matmul(x_2, U_key)
value_2 = torch.matmul(x_2, U_value)

In [25]:
# Queries for every token; row 1 should reproduce query_2.
queries = embedded_sentence @ U_querry
torch.allclose(queries[1], query_2)

True

In [26]:
# Keys for every token; row 1 should reproduce key_2.
keys = embedded_sentence @ U_key
torch.allclose(keys[1], key_2)

True

In [27]:
# Values for every token; row 1 should reproduce value_2.
values = embedded_sentence @ U_value
torch.allclose(values[1], value_2)

True

In [28]:
# Unnormalised score between the second token's query and the third token's key.
omega_23 = torch.dot(query_2, keys[2])
omega_23

tensor(19.1885)

In [29]:
# Scores of the second token's query against every key.
omega_2 = query_2 @ keys.t()
omega_2

tensor([ -0.3244,  29.8293,  19.1885,  30.8173,  46.4730,  40.5626,   4.0052,
        -40.5871])

In [30]:
# Scale the scores by sqrt(d) before the softmax.
attention_weights_2 = (omega_2 / d ** 0.5).softmax(dim=0)
attention_weights_2

tensor([6.5614e-06, 1.2328e-02, 8.6214e-04, 1.5782e-02, 7.9060e-01, 1.8040e-01,
        1.9367e-05, 2.7895e-10])

In [69]:
# Context vector for the second token as an explicit weighted sum of the
# value vectors. Sizes come from `values` rather than the literals 16 and 8,
# so the cell keeps working if the sequence length or value width changes.
context_vector_2 = torch.zeros(values.shape[1])
for j in range(values.shape[0]):
    context_vector_2 += attention_weights_2[j] * values[j, :]

context_vector_2

tensor([-2.3462, -3.4054, -2.2193, -1.9807, -2.2367, -3.0659, -2.5027, -1.9063,
        -1.6545, -2.5260, -2.9716, -1.8847, -2.0280, -3.2097, -2.5135, -2.2080])

In [31]:
# Same weighted sum expressed as a single vector-matrix product.
context_vector_2 = torch.matmul(attention_weights_2, values)
context_vector_2

tensor([-2.3462, -3.4054, -2.2193, -1.9807, -2.2367, -3.0659, -2.5027, -1.9063,
        -1.6545, -2.5260, -2.9716, -1.8847, -2.0280, -3.2097, -2.5135, -2.2080])

In [70]:
# Shape check on the value matrix.
values.size()

torch.Size([8, 16])

In [68]:
# Shapes of the two operands of the weighted sum.
attention_weights_2.size(), values.size()

(torch.Size([8]), torch.Size([8, 16]))

In [71]:
# Context vectors for every token under the query/key/value projections.
# `attention_weights` (from the earlier cell) is the unscaled softmax over raw
# embedding dot products, so it does not belong with `values`. Build the
# weights from the projected queries and keys, scaled by sqrt(d), exactly as
# was done for the single row `attention_weights_2`.
attention_weights_qkv = F.softmax(queries.matmul(keys.T) / d**0.5, dim=1)
z = attention_weights_qkv.matmul(values)
z.shape

torch.Size([8, 16])

In [78]:
# Switch to a longer, ten-token input for the multi-head part.
# NOTE: this rebinds `sentence` and `embedded_sentence`; cells above that were
# run earlier saw the eight-token versions.
longer_token_ids = [0, 7, 1, 2, 5, 6, 4, 3, 7, 6]
sentence = torch.tensor(longer_token_ids)
embedded_sentence = embed(sentence).detach()

# Re-seed so the random projection matrices drawn next are reproducible.
torch.manual_seed(123)

d = embedded_sentence.size(1)

In [79]:
# Number of heads and the per-head query/key and value widths.
h = 8
d_k = 14
d_v = 15

# One projection matrix per head, stacked along the leading dimension.
# Draw order (query, key, value) is kept so the seeded values are unchanged.
multihead_U_query = torch.rand(h, d, d_k)
multihead_U_key = torch.rand(h, d, d_k)
multihead_U_value = torch.rand(h, d, d_v)

In [80]:
# Per-head queries for the second token. Index the current
# `embedded_sentence` directly instead of relying on `x_2`, which was bound in
# an earlier cell before the sentence was rebuilt.
multihead_query_2 = embedded_sentence[1].matmul(multihead_U_query)
multihead_query_2.shape

torch.Size([8, 14])

In [81]:
# Broadcasted matmul: the 2-D embeddings are multiplied by each head's matrix.
multihead_queries = torch.matmul(embedded_sentence, multihead_U_query)
multihead_queries.size()

torch.Size([8, 10, 14])

In [82]:
# Per-head keys for every token.
multihead_keys = torch.matmul(embedded_sentence, multihead_U_key)
multihead_keys.size()

torch.Size([8, 10, 14])

In [83]:
# Per-head values for every token.
multihead_values = torch.matmul(embedded_sentence, multihead_U_value)
multihead_values.size()

torch.Size([8, 10, 15])

In [84]:
# Per-head score matrices: queries times keys with the last two axes swapped.
multihead_omega = multihead_queries @ multihead_keys.transpose(1, 2)
multihead_omega.size()

torch.Size([8, 10, 10])

In [85]:
# Per-head attention weights. Scores are divided by sqrt(d_k) before the
# softmax (scaled dot-product attention), matching the scaling applied in the
# single-head cell; without it the softmax saturates on large dot products.
multihead_Alpha = F.softmax(multihead_omega / d_k**0.5, dim=2)
# Sanity check: weights over the key axis sum to 1 for every head and query.
multihead_Alpha.sum(dim=2)

tensor([[1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000],
        [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000],
        [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000],
        [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000],
        [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000],
        [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000],
        [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000],
        [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000]])

In [86]:
# Shapes of the operands for the per-head weighted sum.
multihead_Alpha.size(), multihead_values.size()

(torch.Size([8, 10, 10]), torch.Size([8, 10, 15]))

In [95]:
# Output projection mapping the concatenated heads (h * d_v) back to width d.
# W_O is drawn before b_O so the random values are unchanged.
W_O = torch.rand(h * d_v, d)
b_O = torch.rand(d)

# Per-head context vectors, then the second token's heads concatenated.
z = torch.matmul(multihead_Alpha, multihead_values)
z_2 = z[:, 1, :].flatten()
z_2_output = torch.matmul(z_2, W_O) + b_O
z_2_output.size()

torch.Size([16])