In the simplified attention weights, we relied on the fact that the embedding layer already has vectors pointing in the right directions, so we could work out the "alignment" between the tokens directly from those vectors.

But remember that we want to keep "an open mind". That embedding layer is a result of whatever training was done, and therefore the tokens' alignments reflect the type and quality of that training data.

What we really want to do is train our neural net with data specific to our domain or use case. That means that as we feed in more data for our use case during training, we want these attention weights to be trainable too.

Remember that we are talking about training the attention weights. We are not (yet) talking about updating the embedding layer vectors.

To do so, we are going to make use of three weight matrices: Key, Query, and Value.
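
As a compact reference, the role of the three matrices can be written out as below. The symbols are notation assumed here rather than names from the code: X is the matrix of token embeddings (one row per token), W_q, W_k and W_v are the query, key and value weight matrices, and d_k is the dimension of each key vector.

```latex
Q = X W_q, \qquad K = X W_k, \qquad V = X W_v
\qquad
\text{context} = \operatorname{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V
```

The softmax is applied row by row, so each token gets its own set of weights over all tokens.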

In [17]:
import torch

inputs = torch.tensor(
    [[0.43, 0.15, 0.89],  # Your     (x^1)
     [0.55, 0.87, 0.66],  # journey  (x^2)
     [0.57, 0.85, 0.64],  # starts   (x^3)
     [0.22, 0.58, 0.33],  # with     (x^4)
     [0.77, 0.25, 0.10],  # one      (x^5)
     [0.05, 0.80, 0.55]]  # step     (x^6)
)

d_in = inputs.shape[1] ## the input embedding size
d_out = 2 ## the output embedding size
inputs.shape

torch.Size([6, 3])

In [4]:
torch.manual_seed(123)

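# Think of torch.nn.Parameter as a multidimensional matrix, for now.
# requires_grad=False switches off gradient tracking for this walkthrough;
# it would have to be True for the matrices to actually be trained.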

W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
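
Since the matrices above are created with `requires_grad=False`, nothing would update them yet. The following is a minimal sketch of what "trainable" means in practice, assuming we simply leave gradient tracking at its default (on). The `toy_loss` is a made-up stand-in for a real training loss, used only to trigger backpropagation.

```python
import torch

torch.manual_seed(123)

inputs = torch.tensor(
    [[0.43, 0.15, 0.89],
     [0.55, 0.87, 0.66],
     [0.57, 0.85, 0.64],
     [0.22, 0.58, 0.33],
     [0.77, 0.25, 0.10],
     [0.05, 0.80, 0.55]]
)
d_in, d_out = inputs.shape[1], 2

# Same shapes as above, but nn.Parameter defaults to requires_grad=True.
W_query = torch.nn.Parameter(torch.rand(d_in, d_out))
W_key = torch.nn.Parameter(torch.rand(d_in, d_out))
W_value = torch.nn.Parameter(torch.rand(d_in, d_out))

queries = inputs @ W_query
keys = inputs @ W_key
values = inputs @ W_value

attn_weights = torch.softmax(queries @ keys.T / keys.shape[-1] ** 0.5, dim=-1)
context_vectors = attn_weights @ values

# Hypothetical loss, only here so that backward() has something to differentiate.
toy_loss = context_vectors.sum()
toy_loss.backward()

# Each weight matrix now carries a gradient of its own shape.
print(W_query.grad.shape)  # torch.Size([3, 2])
print(W_key.grad.shape)    # torch.Size([3, 2])
print(W_value.grad.shape)  # torch.Size([3, 2])
```

An optimizer would use these gradients to nudge the three matrices, which is how the attention weights end up adapting to the training data.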


In [13]:
# Project the embedded vectors through these weight matrices. The matrices are random
# for now, but the whole idea is that they can be updated during training.

keys = inputs @ W_key
values = inputs @ W_value
queries = inputs @ W_query

# Embedding dimension of the keys.
# The first dimension is always the number of tokens (6 in this example).
# The last dimension is the size of the space the vectors were projected
# into, i.e. d_out, the length of each key vector.

d_k = keys.shape[-1]

In [14]:
# attn_scores[i, j] is the dot product of query i with key j
attn_scores = queries @ keys.T
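
To make the matrix product above less opaque, here is a small sketch that checks one entry by hand: the score between token `i` and token `j` should equal the dot product of query `i` with key `j`. It rebuilds the queries and keys with plain tensors (same seed and same order of `torch.rand` calls as above); the indices `i` and `j` are arbitrary picks for illustration.

```python
import torch

torch.manual_seed(123)

inputs = torch.tensor(
    [[0.43, 0.15, 0.89],
     [0.55, 0.87, 0.66],
     [0.57, 0.85, 0.64],
     [0.22, 0.58, 0.33],
     [0.77, 0.25, 0.10],
     [0.05, 0.80, 0.55]]
)
d_in, d_out = inputs.shape[1], 2

# Plain tensors are enough here; W_query is drawn first, then W_key.
W_query = torch.rand(d_in, d_out)
W_key = torch.rand(d_in, d_out)

queries = inputs @ W_query
keys = inputs @ W_key

# All pairwise query-key dot products at once.
attn_scores = queries @ keys.T
print(attn_scores.shape)  # torch.Size([6, 6])

# One entry computed by hand: "journey" (row 1) attending to "one" (row 4).
i, j = 1, 4
single_score = torch.dot(queries[i], keys[j])
print(single_score, attn_scores[i, j])
```

Row `i` of `attn_scores` therefore holds how strongly token `i`'s query matches every token's key, before any normalization.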

In [16]:
# We are going to normalize again, but this time we use dot-product scaling:
# the scores are divided by the square root of the key dimension (d_k)
# before the softmax. This scaling is why this self-attention mechanism
# is called scaled dot-product attention.
attn_weights = torch.softmax(attn_scores / d_k**0.5, dim=-1)
attn_weights

tensor([[0.1551, 0.2104, 0.2059, 0.1413, 0.1074, 0.1799],
        [0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820],
        [0.1503, 0.2256, 0.2192, 0.1315, 0.0914, 0.1819],
        [0.1591, 0.1994, 0.1962, 0.1477, 0.1206, 0.1769],
        [0.1610, 0.1949, 0.1923, 0.1501, 0.1265, 0.1752],
        [0.1557, 0.2092, 0.2048, 0.1419, 0.1089, 0.1794]])
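
Two properties of these weights are easy to check: every row sums to 1, and dividing by `d_k**0.5` makes each row a little less peaked than a plain softmax would. The sketch below compares the two, rebuilding the scores with the same seed as above; the variable names `scaled` and `unscaled` are introduced here only for the comparison.

```python
import torch

torch.manual_seed(123)

inputs = torch.tensor(
    [[0.43, 0.15, 0.89],
     [0.55, 0.87, 0.66],
     [0.57, 0.85, 0.64],
     [0.22, 0.58, 0.33],
     [0.77, 0.25, 0.10],
     [0.05, 0.80, 0.55]]
)
d_in, d_out = inputs.shape[1], 2

W_query = torch.rand(d_in, d_out)
W_key = torch.rand(d_in, d_out)

queries = inputs @ W_query
keys = inputs @ W_key
d_k = keys.shape[-1]
attn_scores = queries @ keys.T

scaled = torch.softmax(attn_scores / d_k**0.5, dim=-1)  # as in the cell above
unscaled = torch.softmax(attn_scores, dim=-1)           # no scaling, for comparison

# Each row is a probability distribution over the six tokens.
print(scaled.sum(dim=-1))

# Largest weight per row: the unscaled version is at least as peaked.
print(unscaled.max(dim=-1).values)
print(scaled.max(dim=-1).values)
```

With `d_k = 2` the effect is small, but dot products tend to grow with the key dimension, so the scaling matters more as `d_k` increases.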

In [19]:
context_vectors = attn_weights @ values

context_vectors

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]])
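
The steps above can be collected into a single module so that the three matrices are registered as trainable parameters. This is a minimal sketch under that assumption, not a definitive implementation; the class name `SelfAttentionSketch` is hypothetical and chosen only for this example.

```python
import torch
import torch.nn as nn


class SelfAttentionSketch(nn.Module):  # hypothetical name, for illustration only
    def __init__(self, d_in, d_out):
        super().__init__()
        # nn.Parameter tracks gradients by default, so these are trainable.
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        # Project the token embeddings into query, key and value spaces.
        queries = x @ self.W_query
        keys = x @ self.W_key
        values = x @ self.W_value

        # Scaled dot-product attention, row-wise softmax over the keys.
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        return attn_weights @ values


torch.manual_seed(123)

inputs = torch.tensor(
    [[0.43, 0.15, 0.89],
     [0.55, 0.87, 0.66],
     [0.57, 0.85, 0.64],
     [0.22, 0.58, 0.33],
     [0.77, 0.25, 0.10],
     [0.05, 0.80, 0.55]]
)

sa = SelfAttentionSketch(d_in=inputs.shape[1], d_out=2)
context_vectors = sa(inputs)
print(context_vectors.shape)  # torch.Size([6, 2])
```

Wrapping the matrices in a module means `sa.parameters()` can be handed straight to an optimizer, which is what training the attention weights requires.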