Transformers

In [1]:
import torch


# Toy input sequence, encoded as one integer token id per word:
#   "Can you help me to translate this sentence"
# word -> id: can=0, you=7, help=1, me=2, to=5, translate=6, this=4, sentence=3
sentence = torch.tensor([0, 7, 1, 2, 5, 6, 4, 3])

sentence

tensor([0, 7, 1, 2, 5, 6, 4, 3])

In [9]:
# Seed the RNG first so the randomly initialised embedding table is reproducible.
torch.manual_seed(123)

# Lookup table with 10 rows (token ids 0..9), each a 16-dimensional vector.
embed = torch.nn.Embedding(num_embeddings=10, embedding_dim=16)

# One embedding row per token of the sentence; detach() drops autograd
# tracking because only the values are needed below.
embedded_sentence = embed(sentence).detach()
embedded_sentence.shape

torch.Size([8, 16])

In [3]:
# Unnormalised self-attention scores computed the explicit (slow) way:
# omega[i, j] is the dot product between the embeddings of token i and token j.
# The matrix is sized from the data instead of a hard-coded 8 x 8, so the cell
# keeps working when the input sentence has a different number of tokens.
num_tokens = embedded_sentence.shape[0]
omega = torch.empty(num_tokens, num_tokens)

for i, x_i in enumerate(embedded_sentence):
    for j, x_j in enumerate(embedded_sentence):
        omega[i, j] = torch.dot(x_i, x_j)

In [4]:
# Same pairwise dot-product scores as the double loop, as one matrix product.
omega_mat = embedded_sentence @ embedded_sentence.T

In [5]:
# Sanity check: the vectorised scores must match the loop-built ones
# (tolerances written out; these are torch.allclose's documented defaults).
torch.allclose(omega_mat, omega, rtol=1e-05, atol=1e-08)

True

In [6]:
import torch.nn.functional as F

# Normalise every row of the score matrix with a softmax so that the
# weights one token assigns to all tokens sum to 1.
attention_weights = torch.softmax(omega, dim=1)
attention_weights.shape

torch.Size([8, 8])

In [7]:
# Each row of the attention matrix should add up to 1 after the softmax.
torch.sum(attention_weights, dim=1)

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])

In [10]:
# Context vector of the 2nd token (row index 1), built explicitly as the
# attention-weighted sum of all token embeddings.
x_2 = embedded_sentence[1, :]
context_vec_2 = torch.zeros(x_2.shape)
# Loop over however many tokens the sentence has rather than a hard-coded
# range(8), so the cell stays correct for other sentence lengths.
for j in range(embedded_sentence.shape[0]):
    x_j = embedded_sentence[j, :]
    context_vec_2 += attention_weights[1, j] * x_j

context_vec_2

tensor([-9.3975e-01, -4.6856e-01,  1.0311e+00, -2.8192e-01,  4.9373e-01,
        -1.2896e-02, -2.7327e-01, -7.6358e-01,  1.3958e+00, -9.9543e-01,
        -7.1287e-04,  1.2449e+00, -7.8077e-02,  1.2765e+00, -1.4589e+00,
        -2.1601e+00])

In [11]:
# All context vectors at once: row i is the attention-weighted sum of the
# token embeddings for token i.
context_vectors = attention_weights @ embedded_sentence

# Row 1 must reproduce the context vector accumulated by hand for the 2nd token.
torch.allclose(context_vec_2, context_vectors[1])

True

In [12]:
# Re-seed so the projection matrices drawn below are reproducible.
torch.manual_seed(123)

# d is the embedding dimension (second axis of embedded_sentence).
d = embedded_sentence.size(1)

# Query, key and value projection matrices, each d x d with random entries.
# The comprehension draws them in the order query, key, value, so the random
# stream is consumed exactly as three consecutive torch.rand calls would.
U_query, U_key, U_value = [torch.rand(d, d) for _ in range(3)]


In [13]:
# Embedding of the 2nd token (row index 1) and its projection to a query vector.
x_2 = embedded_sentence[1]
query_2 = U_query @ x_2


In [14]:
# Key and value vectors of the same (2nd) token, using their own projections.
key_2 = U_key @ x_2
value_2 = U_value @ x_2


In [15]:

# Key vectors for every token at once. The embeddings are transposed so each
# token becomes a column for the projection, then transposed back so that
# row i of `keys` is the key of token i.
keys = (U_key @ embedded_sentence.T).T
# Row 1 must equal the key computed individually for the 2nd token.
torch.allclose(key_2, keys[1])

True

In [20]:
# Small illustration of a transpose: a 2 x 3 tensor becomes 3 x 2.
print(torch.tensor([[2, 3, 5], [4, 5, 6]]).t())

tensor([[2, 4],
        [3, 5],
        [5, 6]])


In [21]:
# Unnormalised attention score between the query of the 2nd token and the
# key of the 3rd token (row index 2).
omega_23 = torch.dot(query_2, keys[2])
omega_23

tensor(14.3667)

In [22]:
# Scores of the 2nd token's query against the keys of all tokens in one product.
omega_2 = query_2 @ keys.T
omega_2

tensor([-25.1623,   9.3602,  14.3667,  32.1482,  53.8976,  46.6626,  -1.2131,
        -32.9391])

In [26]:
# Display the query vector of the 2nd token for inspection.
query_2

tensor([-1.2403, -2.9754, -0.2894, -0.4004, -2.9577, -0.2939, -0.2266, -3.6482,
        -2.6450, -0.9536, -1.1116,  1.1717, -2.2671, -0.7874, -2.0140, -1.6652])

In [27]:
# Scaled dot-product attention weights for the 2nd token: divide the scores
# by sqrt(d) before the softmax, then normalise over all tokens.
attention_weights_2 = torch.softmax(omega_2 / d**0.5, dim=0)
attention_weights_2

tensor([2.2317e-09, 1.2499e-05, 4.3696e-05, 3.7242e-03, 8.5596e-01, 1.4026e-01,
        8.8897e-07, 3.1936e-10])