# Self-attention with trainable weights

In [None]:
import torch

# Toy input sentence "Your journey starts with one step": 6 tokens, each
# embedded as a small 3-dimensional vector so the numbers stay easy to follow.
# Row i holds the embedding x^(i+1) of the (i+1)-th token.
inputs = torch.tensor(
    [[0.43, 0.15, 0.89],  # Your     (x^1)
     [0.55, 0.87, 0.66],  # journey  (x^2)
     [0.57, 0.85, 0.64],  # starts   (x^3)
     [0.22, 0.58, 0.33],  # with     (x^4)
     [0.77, 0.25, 0.10],  # one      (x^5)
     [0.05, 0.80, 0.55]]  # step     (x^6)
)

In [None]:
# Use the second token ("journey") as the running example.
x_2 = inputs[1]
# Projection sizes: the input width is the embedding size; the
# query/key/value width is deliberately set to 2 for illustration.
d_in = inputs.size(1)
d_out = 2

Initialize the query, key, and value weight matrices

In [None]:
torch.manual_seed(123)
# Three (d_in, d_out) projection matrices, drawn in the order query, key, value.
# Gradients stay disabled for now to keep the printouts clean; training would
# switch requires_grad on.
W_query, W_key, W_value = (
    torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
    for _ in range(3)
)

In [None]:
# Show the three weight matrices, one after another.
print(W_query, W_key, W_value, sep="\n")

Parameter containing:
tensor([[0.2961, 0.5166],
        [0.2517, 0.6886],
        [0.0740, 0.8665]])
Parameter containing:
tensor([[0.1366, 0.1025],
        [0.1841, 0.7264],
        [0.3153, 0.6871]])
Parameter containing:
tensor([[0.0756, 0.1966],
        [0.3164, 0.4017],
        [0.1186, 0.8274]])


In [None]:
# Project the "journey" embedding into query, key, and value space.
query_2, key_2, value_2 = (x_2 @ W for W in (W_query, W_key, W_value))
print(query_2, key_2, value_2, sep="\n")

tensor([0.4306, 1.4551])
tensor([0.4433, 1.1419])
tensor([0.3951, 1.0037])


In [None]:
# Project every token at once: each matrix maps all 6 rows of `inputs` to
# d_out-dimensional query / key / value vectors.
queries, keys, values = (inputs @ W for W in (W_query, W_key, W_value))
print("Keys shape:", keys.shape)
print("Queries shape:", queries.shape)
print("Values shape:", values.shape)

Keys shape: torch.Size([6, 2])
Queries shape: torch.Size([6, 2])
Values shape: torch.Size([6, 2])


In [None]:
# Inspect the projected queries, keys and values for all tokens.
print(queries, keys, values, sep="\n")

tensor([[0.2309, 1.0966],
        [0.4306, 1.4551],
        [0.4300, 1.4343],
        [0.2355, 0.7990],
        [0.2983, 0.6565],
        [0.2568, 1.0533]])
tensor([[0.3669, 0.7646],
        [0.4433, 1.1419],
        [0.4361, 1.1156],
        [0.2408, 0.6706],
        [0.1827, 0.3292],
        [0.3275, 0.9642]])
tensor([[0.1855, 0.8812],
        [0.3951, 1.0037],
        [0.3879, 0.9831],
        [0.2393, 0.5493],
        [0.1492, 0.3346],
        [0.3221, 0.7863]])


Attention scores for "journey" (query x^2)

In [None]:
# Dot the "journey" query with every key. Multiplying by keys.T
# (shape d_out x num_tokens) yields one score per token, i.e. 6 scores.
attn_scores_2 = torch.matmul(query_2, keys.T)
print(attn_scores_2)

tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])


Attention scores for all queries at once.

In [None]:
# Row i holds the scores of query i against every key.
attn_scores = torch.matmul(queries, keys.T)
print(attn_scores)

tensor([[0.9231, 1.3545, 1.3241, 0.7910, 0.4032, 1.1330],
        [1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440],
        [1.2544, 1.8284, 1.7877, 1.0654, 0.5508, 1.5238],
        [0.6973, 1.0167, 0.9941, 0.5925, 0.3061, 0.8475],
        [0.6114, 0.8819, 0.8626, 0.5121, 0.2707, 0.7307],
        [0.8995, 1.3165, 1.2871, 0.7682, 0.3937, 1.0996]])


Next, normalize the attention scores into attention weights.

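Scaled dot-product attention: divide the scores by the square root of the key dimension $d_k$ (here $d_k = 2$, so we divide by $\sqrt{2}$),
and then apply softmax row-wise so that each query's weights sum to 1.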

In [None]:
# Key dimension d_k; dividing by sqrt(d_k) before the row-wise softmax is the
# "scaled" part of scaled dot-product attention.
d_k = keys.size(-1)
attn_weights = attn_scores.div(d_k ** 0.5).softmax(dim=-1)
print(attn_weights)

tensor([[0.1551, 0.2104, 0.2059, 0.1413, 0.1074, 0.1799],
        [0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820],
        [0.1503, 0.2256, 0.2192, 0.1315, 0.0914, 0.1819],
        [0.1591, 0.1994, 0.1962, 0.1477, 0.1206, 0.1769],
        [0.1610, 0.1949, 0.1923, 0.1501, 0.1265, 0.1752],
        [0.1557, 0.2092, 0.2048, 0.1419, 0.1089, 0.1794]])


In [None]:
# Context vector for "journey": normalize its scaled scores into weights,
# then take the weighted sum of all value vectors.
attn_weights_2 = torch.softmax(attn_scores_2 / d_k**0.5, dim=-1)
# Sanity check: this must equal row 2 of the all-queries weight matrix.
assert torch.allclose(attn_weights_2, attn_weights[1])
context_vec_2 = attn_weights_2 @ values
print(context_vec_2)

tensor([0.3061, 0.8210])


In [None]:
# Context vectors for every token: row i mixes the values with query i's weights.
context_vector = torch.matmul(attn_weights, values)
print(context_vector)

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]])


Implementing a compact self-attention Python class

In [None]:
import torch.nn as nn

class SelfAttention_V1(nn.Module):
  """Scaled dot-product self-attention with weights stored as raw nn.Parameters.

  Args:
    d_in: size of each input embedding (last dimension of ``x``).
    d_out: size of the query/key/value projections and of the output.

  The three (d_in, d_out) weight matrices are initialized with ``torch.rand``
  in the order query, key, value, so a fixed seed reproduces them exactly.
  """
  def __init__(self, d_in, d_out):
    super().__init__()
    self.W_query = nn.Parameter(torch.rand(d_in, d_out))
    self.W_key = nn.Parameter(torch.rand(d_in, d_out))
    self.W_value = nn.Parameter(torch.rand(d_in, d_out))

  def forward(self, x):
    """Return one context vector per token.

    Args:
      x: tensor of shape (num_tokens, d_in) or (batch, num_tokens, d_in).

    Returns:
      Tensor of shape (..., num_tokens, d_out).
    """
    queries = x @ self.W_query
    keys = x @ self.W_key
    values = x @ self.W_value

    # Swap only the last two dims: identical to `.T` for 2-D input, but also
    # correct for batched 3-D input (where `.T` would reverse every dimension).
    attn_scores = queries @ keys.transpose(-2, -1)
    attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
    context_vector = attn_weights @ values

    return context_vector

In [None]:
# Same seed as the manual weight cell, so V1 draws identical weights and
# reproduces the hand-computed context vectors.
torch.manual_seed(123)
sa_v1 = SelfAttention_V1(d_in=d_in, d_out=d_out)
print(sa_v1(inputs))

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)


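We can improve V1 by using PyTorch's `nn.Linear` layers. With `bias=False` they perform the same matrix multiplication as the manual `nn.Parameter` weights,
but `nn.Linear` uses a better-suited weight initialization scheme than `torch.rand`, which contributes to more stable and effective model training. Because the initialization (and the seed) differ, the V2 outputs below differ from V1.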

In [None]:
class SelfAttention_V2(nn.Module):
  """Scaled dot-product self-attention using nn.Linear projections.

  Args:
    d_in: size of each input embedding (last dimension of ``x``).
    d_out: size of the query/key/value projections and of the output.
    qkv_bias: whether the query/key/value projections include a bias term.
  """
  def __init__(self, d_in, d_out, qkv_bias=False):
    super().__init__()
    self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

  def forward(self, x):
    """Return one context vector per token.

    Args:
      x: tensor of shape (num_tokens, d_in) or (batch, num_tokens, d_in).

    Returns:
      Tensor of shape (..., num_tokens, d_out).
    """
    queries = self.W_query(x)
    keys = self.W_key(x)
    values = self.W_value(x)

    # Swap only the last two dims: identical to `.T` for 2-D input, but also
    # correct for batched 3-D input (where `.T` would reverse every dimension).
    attn_scores = queries @ keys.transpose(-2, -1)
    attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
    context_vector = attn_weights @ values

    return context_vector

In [None]:
# V2 uses nn.Linear initialization and a different seed, so its context
# vectors are not expected to match V1's.
torch.manual_seed(789)
sa_v2 = SelfAttention_V2(d_in=d_in, d_out=d_out)
print(sa_v2(inputs))

tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)
