# Trainable Self-attention

Inputs

In [1]:
import torch

# Toy embeddings: one row per token of the example sentence, 3 features each.
inputs = torch.tensor([
    [0.43, 0.15, 0.89],  # x^1: Your
    [0.55, 0.87, 0.66],  # x^2: journey
    [0.57, 0.85, 0.64],  # x^3: starts
    [0.22, 0.58, 0.33],  # x^4: with
    [0.77, 0.25, 0.10],  # x^5: one
    [0.05, 0.80, 0.55],  # x^6: step
])

d_in = inputs.size(-1)  # width of each input embedding (3 here)
d_out = 2               # width of the query/key/value projections

x_2 = inputs[1]         # the second token, used as the running example

## Q/K/V

Initialize Q-K-V Weights

In [3]:
torch.manual_seed(123)

# Three independent (d_in, d_out) projection matrices, frozen for this demo.
# The generator is consumed left to right, so the random draws land on
# W_query, then W_key, then W_value.
W_query, W_key, W_value = (
    torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
    for _ in range(3)
)

In [16]:
# Project the second token into query space.
query_2 = torch.matmul(x_2, W_query)

print("query_2:", query_2)

query_2: tensor([0.4306, 1.4551])


In [5]:
# Project every token at once: one key row and one value row per input row.
keys = torch.matmul(inputs, W_key)
values = torch.matmul(inputs, W_value)

print("keys.shape:", keys.shape)
print("values.shape:", values.shape)

keys.shape: torch.Size([6, 2])
values.shape: torch.Size([6, 2])


## Attention Score

Elementary calculation

In [8]:
# Unnormalized score of the second token attending to itself:
# the dot product of its query with its own key.
keys_2 = keys[1]
attn_score_22 = torch.dot(query_2, keys_2)

print("attn_score_22:", attn_score_22)

attn_score_22: tensor(1.8524)


Attention scores of inputs[1] (x^2) for all the keys

In [9]:
# One score per key: the query of the second token against every key row.
attn_scores_2 = torch.matmul(query_2, keys.T)

print("All attention scores for given query input[1]: ", attn_scores_2)

All attention scores for given query input[1]:  tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])


## Attention Weight

In [10]:
# Scale the scores by sqrt(d_k), the key dimension, before the softmax;
# softmax over the last axis turns them into weights that sum to 1.
d_k = keys.shape[-1]
attn_weights_2 = torch.softmax(attn_scores_2 / d_k**0.5, dim=-1)

# The query here is inputs[1] (x^2), so the label names that element.
print("Attention Weight for inputs[1]: ", attn_weights_2)

Attension Weight for inputs[2]:  tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])


## Context 

### Manual Calculation

In [12]:
# Context vector = attention-weighted sum of all value rows.
context_vec_2 = torch.matmul(attn_weights_2, values)

print("Context Vector for inputs[1]: ", context_vec_2)

Context Vector for inputs[1]:  tensor([0.3061, 0.8210])


### Self-Attention class

In [17]:
import torch.nn as nn

class SelfAttention_v1(nn.Module):
    """Single-head scaled dot-product self-attention using raw weight matrices.

    Args:
        d_in: size of the last dimension of the input embeddings.
        d_out: size of the query/key/value projections, and therefore of
            each returned context vector.
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        # Creation order (query, key, value) decides which random draws each
        # matrix receives under a fixed seed, so keep it stable.
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key   = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        """Compute one context vector per input token.

        Args:
            x: tensor of shape (num_tokens, d_in), or with leading batch
                dimensions, e.g. (batch, num_tokens, d_in).

        Returns:
            Tensor shaped like ``x`` with the last dimension replaced by d_out.
        """
        queries = x @ self.W_query
        keys = x @ self.W_key
        values = x @ self.W_value

        # Swap only the last two dims: identical to ``keys.T`` for 2-D input,
        # but also correct for batched input, where ``.T`` would reverse
        # every dimension.
        attn_scores = queries @ keys.transpose(-2, -1) # omega
        # Scale by sqrt(d_out) before normalizing across the keys.
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )
        context_vec = attn_weights @ values
        return context_vec

# Seed immediately before construction so the random weights are reproducible.
torch.manual_seed(123)

sa_v1 = SelfAttention_v1(d_in=d_in, d_out=d_out)

# Calling the module on the full input matrix yields one context vector per token.
print("Context Vector for the whole inputs: \n", sa_v1(inputs))

Context Vector for the whole inputs: 
 tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)


### Class V2

In [19]:
class SelfAttention_v2(nn.Module):
    """Single-head scaled dot-product self-attention built on ``nn.Linear``.

    Args:
        d_in: size of the last dimension of the input embeddings.
        d_out: size of the query/key/value projections, and therefore of
            each returned context vector.
        qkv_bias: whether the three projection layers carry a bias term.
    """

    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        # Creation order (query, key, value) decides which random draws each
        # layer receives under a fixed seed, so keep it stable.
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        """Compute one context vector per input token.

        Args:
            x: tensor of shape (num_tokens, d_in), or with leading batch
                dimensions, e.g. (batch, num_tokens, d_in).

        Returns:
            Tensor shaped like ``x`` with the last dimension replaced by d_out.
        """
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)

        # Swap only the last two dims: identical to ``keys.T`` for 2-D input,
        # but also correct for batched input, where ``.T`` would reverse
        # every dimension.
        attn_scores = queries @ keys.transpose(-2, -1)
        # Scale by sqrt(d_out) before normalizing across the keys.
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        context_vec = attn_weights @ values
        return context_vec

# Seed immediately before construction so the layer initialization is reproducible.
torch.manual_seed(789)

sa_v2 = SelfAttention_v2(d_in=d_in, d_out=d_out)

# Calling the module on the full input matrix yields one context vector per token.
print("Context Vector for the whole inputs: \n", sa_v2(inputs))

Context Vector for the whole inputs: 
 tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)
