In [None]:
# Self-attention with trainable weights (scaled dot-product attention).
# Three weight matrices -- W_query, W_key, W_value -- project every input
# embedding into a query, a key and a value vector. In a real model these
# matrices are learned so that the resulting context vectors are useful.
# This walkthrough computes the context vector for a single input (x_2).
import torch

inputs = torch.tensor([
    [0.43, 0.15, 0.89],  # "Your"
    [0.55, 0.87, 0.66],  # "journey"
    [0.57, 0.85, 0.64],  # "starts"
    [0.22, 0.58, 0.33],  # "with"
    [0.77, 0.25, 0.10],  # "one"
    [0.05, 0.80, 0.55],  # "step"
])

x_2 = inputs[1]         # second input row ("journey"); used as the query token
d_in = inputs.shape[1]  # input embedding size
d_out = 2               # output size; chosen different from d_in so the computation is easier to follow

# Initializing the W_query, W_key, W_value matrices.
# The seed fixes the random draws; the three torch.rand calls below must stay
# in this order for the values to be reproducible.
torch.manual_seed(123)
# nn.Parameter vs nn.Linear
# nn.Parameter is the lower-level option: we build the weight tensor ourselves
# and wrap it. The wrapper marks the tensor as a learnable weight when
# requires_grad=True (the default); after that we choose which operation to
# apply with it. requires_grad=False is passed below, so autograd does not
# track these matrices in this walkthrough.
#
# nn.Linear is a fully connected layer. We only provide the in_features and
# out_features sizes and it computes x @ W.T + b (b is omitted when bias=False).
W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

# Query, key and value vectors for x_2.
query_2 = x_2 @ W_query
key_2 = x_2 @ W_key
value_2 = x_2 @ W_value
print(query_2)

# Keys and values are needed for ALL inputs, because the attention weights for
# query_2 are computed against every key and then applied to every value.
keys = inputs @ W_key
values = inputs @ W_value

print("keys.shape: ", keys.shape)
print("values,shape: ", values.shape)

# Attention score w22: query of x_2 against the key of x_2.
keys_2 = keys[1]  # row 1 of keys, i.e. the key vector belonging to x_2
# In scaled dot-product attention the score is the dot product of the query
# with a key, both obtained by transforming the inputs with their weight
# matrices. In basic self-attention the score was the dot product of the raw
# input vectors themselves.
attn_score_22 = query_2.dot(keys_2)
print(attn_score_22)

# Generalizing: scores of query_2 against every key in one matrix product.
# Print the shapes (not the tensors) so the output matches its labels.
print("query shape:", query_2.shape)
print("keys shape:", keys.shape)
attn_scores_2 = query_2 @ keys.T
print(attn_scores_2)

# Attention weights.
d_k = keys.shape[-1]
# Dividing the scores by sqrt(d_k), the key dimension, before the softmax is
# what makes this "scaled" dot-product attention.
attn_weights_2 = torch.softmax(attn_scores_2 / d_k**0.5, dim=-1)
print(attn_weights_2)

# Context vector.
# Basic self-attention: weighted sum over the input vectors.
# Scaled dot-product attention: weighted sum over the value vectors.
context_vec_2 = attn_weights_2 @ values
print(context_vec_2)


tensor([0.4306, 1.4551])
keys.shape:  torch.Size([6, 2])
values,shape:  torch.Size([6, 2])
tensor(1.8524)
query shape: tensor([0.4306, 1.4551])
keys shape: tensor([[0.3669, 0.7646],
        [0.4433, 1.1419],
        [0.4361, 1.1156],
        [0.2408, 0.6706],
        [0.1827, 0.3292],
        [0.3275, 0.9642]])
tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])
tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])
tensor([0.3061, 0.8210])
