## Implementing Self-Attention with Trainable Weights

In [26]:
import torch

In [27]:
# One 3-component vector per word of "Your journey starts with one step";
# row i (0-based) holds the vector for the (i+1)-th word.
inputs = torch.tensor([
    [0.43, 0.15, 0.89],  # x^1: Your
    [0.55, 0.87, 0.66],  # x^2: journey
    [0.57, 0.85, 0.64],  # x^3: starts
    [0.22, 0.58, 0.33],  # x^4: with
    [0.77, 0.25, 0.10],  # x^5: one
    [0.05, 0.80, 0.55],  # x^6: step
])

In [28]:
# Check the layout of the input matrix: (number of rows, components per row).
inputs.shape

torch.Size([6, 3])

In [29]:
# Projection sizes: d_in is the width of each input row, d_out the width
# chosen for the projected query/key/value vectors.
d_in, d_out = inputs.size(1), 2

# Second row of `inputs` (index 1), used as the worked example.
x_2 = inputs[1]

In [30]:
# Fix the RNG so the randomly drawn weights are the same on every run.
torch.manual_seed(123)

# Three independent (d_in x d_out) projection matrices, drawn in the order
# query -> key -> value. requires_grad=False keeps autograd from tracking them.
W_query, W_key, W_value = (
    torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
    for _ in range(3)
)

In [31]:
# Show the query projection matrix (d_in x d_out) drawn under seed 123.
print(W_query)

Parameter containing:
tensor([[0.2961, 0.5166],
        [0.2517, 0.6886],
        [0.0740, 0.8665]])


In [32]:
# Show the key projection matrix (d_in x d_out).
print(W_key)

Parameter containing:
tensor([[0.1366, 0.1025],
        [0.1841, 0.7264],
        [0.3153, 0.6871]])


In [33]:
# Show the value projection matrix (d_in x d_out).
print(W_value)

Parameter containing:
tensor([[0.0756, 0.1966],
        [0.3164, 0.4017],
        [0.1186, 0.8274]])


In [34]:
# Project the example row x_2 through each weight matrix to obtain its
# query, key and value vectors (each of length d_out).
query_2 = torch.matmul(x_2, W_query)
key_2 = torch.matmul(x_2, W_key)
value_2 = torch.matmul(x_2, W_value)
print(query_2)

tensor([0.4306, 1.4551])


In [35]:
# The same three projections, applied to every row of `inputs` at once.
queries = torch.matmul(inputs, W_query)
keys = torch.matmul(inputs, W_key)
values = torch.matmul(inputs, W_value)

# One shape per line, in the order keys, values, queries.
print(keys.shape, values.shape, queries.shape, sep="\n")

torch.Size([6, 2])
torch.Size([6, 2])
torch.Size([6, 2])


In [36]:
# Unnormalised attention score of the second row with itself: the dot
# product of its query vector and its own key vector (row 1 of `keys`).
keys_2 = keys[1]
attn_score_22 = torch.dot(query_2, keys_2)
attn_score_22

tensor(1.8524)

In [37]:
# Scores of query_2 against every key at once; entry j is query_2 . keys[j].
attn_scores_2 = torch.matmul(query_2, keys.t())
attn_scores_2

tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])

In [38]:
# Full score matrix: entry (i, j) is the dot product of queries[i] and keys[j].
attn_scores = torch.matmul(queries, keys.t())
attn_scores

tensor([[0.9231, 1.3545, 1.3241, 0.7910, 0.4032, 1.1330],
        [1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440],
        [1.2544, 1.8284, 1.7877, 1.0654, 0.5508, 1.5238],
        [0.6973, 1.0167, 0.9941, 0.5925, 0.3061, 0.8475],
        [0.6114, 0.8819, 0.8626, 0.5121, 0.2707, 0.7307],
        [0.8995, 1.3165, 1.2871, 0.7682, 0.3937, 1.0996]])

In [39]:
# Key dimensionality (size of the last axis of `keys`), used to scale the scores.
d_k = keys.size(-1)
d_k

2

In [40]:
# Scaled dot-product attention weights for the second row: divide the scores
# by sqrt(d_k), then softmax over the last axis so the weights sum to 1.
attn_weights_2 = (attn_scores_2 / d_k ** 0.5).softmax(dim=-1)
attn_weights_2

tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])