<a href="https://colab.research.google.com/github/shreyans-sureja/llm-101/blob/main/part8.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Self attention with trainable weights

The self-attention mechanism implemented here is also called "scaled dot-product attention".

1. We want to compute context vectors as weighted sums over the input vectors, with weights specific to each input element.
2. We will introduce weight matrices that are updated during model training.
3. These trainable weight matrices are crucial so that the model can learn to produce good context vectors.
4. We will implement the self-attention mechanism step by step by introducing three trainable weight matrices: Wq, Wk, and Wv (query, key, and value).


The first step is to convert the input embeddings into query, key, and value vectors.

sentence = Your journey starts with one step. \
embedding dimension = 3 \
The input is a 6x3 matrix (6 tokens, 3 dimensions each).

We will use 3x2 matrices Wq, Wk, and Wv, initialized with random values.

**Queries = Inputs * Wq [6x2 matrix]**

**Keys = Inputs * Wk [6x2 matrix]**

**Values = Inputs * Wv [6x2 matrix]**

So every row of these matrices represents one input token.

In [1]:
import torch

inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

Note that in GPT-like models, the input and output dimensions are usually the same.

But for illustration purposes, to better follow the computation, we choose different input (d_in=3)
and output (d_out=2) dimensions here.

In [3]:
x_2 = inputs[1]          # second input element ("journey")
d_in = inputs.shape[1]   # input embedding size, d_in=3
d_out = 2                # output embedding size, d_out=2

In [4]:
torch.manual_seed(123)
W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

In [5]:
print(W_query)

Parameter containing:
tensor([[0.2961, 0.5166],
        [0.2517, 0.6886],
        [0.0740, 0.8665]])


Note that we are setting requires_grad=False to reduce clutter in the outputs for
illustration purposes.

If we were to use the weight matrices for model training, we
would set requires_grad=True to update these matrices during model training.
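
To make the effect of requires_grad=True concrete, here is a minimal sketch. The name `W_query_trainable` and the toy loss (`query_2.sum()`) are assumptions chosen only for illustration; they are not a real training objective. With gradient tracking enabled, a backward pass fills in the `.grad` attribute of the parameter.

```python
import torch

torch.manual_seed(123)
d_in, d_out = 3, 2

# Same shape as W_query above, but now tracked by autograd
W_query_trainable = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=True)

x_2 = torch.tensor([0.55, 0.87, 0.66])  # embedding of "journey"
query_2 = x_2 @ W_query_trainable

# Toy scalar "loss", used only to trigger a backward pass
toy_loss = query_2.sum()
toy_loss.backward()

# d(toy_loss)/dW[i, j] = x_2[i], so every column of the gradient equals x_2
print(W_query_trainable.grad)
```

An optimizer would use this gradient to update the matrix during training.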


In [6]:
print(x_2)

tensor([0.5500, 0.8700, 0.6600])


In [7]:
query_2 = x_2 @ W_query
key_2 = x_2 @ W_key
value_2 = x_2 @ W_value

In [8]:
print(query_2)

tensor([0.4306, 1.4551])


In [9]:
keys = inputs @ W_key
values = inputs @ W_value
queries = inputs @ W_query

In [10]:
print(keys)

tensor([[0.3669, 0.7646],
        [0.4433, 1.1419],
        [0.4361, 1.1156],
        [0.2408, 0.6706],
        [0.1827, 0.3292],
        [0.3275, 0.9642]])
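
As a quick sanity check, the sketch below rebuilds the tensors from the cells above with the same seed and confirms that projecting all tokens at once gives the same result as projecting a single token. Row 1 of `keys` should match the key vector computed for x_2 alone.

```python
import torch

inputs = torch.tensor(
    [[0.43, 0.15, 0.89],  # Your
     [0.55, 0.87, 0.66],  # journey
     [0.57, 0.85, 0.64],  # starts
     [0.22, 0.58, 0.33],  # with
     [0.77, 0.25, 0.10],  # one
     [0.05, 0.80, 0.55]]  # step
)

torch.manual_seed(123)
d_in, d_out = inputs.shape[1], 2
W_query = torch.rand(d_in, d_out)
W_key = torch.rand(d_in, d_out)
W_value = torch.rand(d_in, d_out)

# Project a single token ...
key_2 = inputs[1] @ W_key

# ... and all six tokens in one matrix multiplication
keys = inputs @ W_key

print(keys.shape)                      # torch.Size([6, 2])
print(torch.allclose(keys[1], key_2))  # True
```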


Next step: compute the attention scores.

We start by computing only one context vector, z_2. The attention score between two tokens is the dot product of the query vector of the current token and the key vector of the other token. The scores for token 2 are therefore the dot products of query_2 with the key vectors of all tokens (including token 2 itself).


query_2 = [1x2]

keys = [6x2]

query_2 * keys.transpose = [1x6]

In [11]:
keys_2 = keys[1]
print(keys_2)

tensor([0.4433, 1.1419])


In [13]:
attn_scores_2 = query_2 @ keys.T
print(attn_scores_2)

tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])
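
Each entry of attn_scores_2 is a single dot product. The sketch below rebuilds the tensors with the same seed and checks this for the score of token 2 with itself. The name `attn_score_22` is a hypothetical name introduced here for illustration.

```python
import torch

inputs = torch.tensor(
    [[0.43, 0.15, 0.89],
     [0.55, 0.87, 0.66],
     [0.57, 0.85, 0.64],
     [0.22, 0.58, 0.33],
     [0.77, 0.25, 0.10],
     [0.05, 0.80, 0.55]]
)

torch.manual_seed(123)
W_query = torch.rand(3, 2)
W_key = torch.rand(3, 2)

query_2 = inputs[1] @ W_query
keys = inputs @ W_key

# Score between token 2's query and token 2's own key: one dot product
attn_score_22 = query_2.dot(keys[1])

# Scores of token 2's query against every key: one matrix-vector product
attn_scores_2 = query_2 @ keys.T

print(attn_scores_2.shape)                                    # torch.Size([6])
print(torch.isclose(attn_score_22, attn_scores_2[1]).item())  # True
```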


In [14]:
attn_scores = queries @ keys.T
print(attn_scores)

tensor([[0.9231, 1.3545, 1.3241, 0.7910, 0.4032, 1.1330],
        [1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440],
        [1.2544, 1.8284, 1.7877, 1.0654, 0.5508, 1.5238],
        [0.6973, 1.0167, 0.9941, 0.5925, 0.3061, 0.8475],
        [0.6114, 0.8819, 0.8626, 0.5121, 0.2707, 0.7307],
        [0.8995, 1.3165, 1.2871, 0.7682, 0.3937, 1.0996]])


In [15]:
print(torch.softmax(attn_scores, dim=-1))

tensor([[0.1484, 0.2285, 0.2217, 0.1301, 0.0883, 0.1831],
        [0.1401, 0.2507, 0.2406, 0.1157, 0.0687, 0.1842],
        [0.1406, 0.2496, 0.2397, 0.1164, 0.0696, 0.1841],
        [0.1548, 0.2130, 0.2083, 0.1394, 0.1047, 0.1799],
        [0.1577, 0.2067, 0.2028, 0.1428, 0.1122, 0.1777],
        [0.1494, 0.2267, 0.2202, 0.1310, 0.0901, 0.1825]])


We compute the attention weights by scaling the
attention scores and applying the softmax function we used earlier.

The difference from earlier is
that we now scale the attention scores by dividing them by the **square root of the
embedding dimension of the keys.**

Note that taking the square root is mathematically the
same as exponentiating by 0.5.

In [19]:
d_k = keys.shape[-1]
print(d_k)

attn_weights = torch.softmax(attn_scores / d_k**0.5, dim=-1)
print(attn_weights)

2
tensor([[0.1551, 0.2104, 0.2059, 0.1413, 0.1074, 0.1799],
        [0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820],
        [0.1503, 0.2256, 0.2192, 0.1315, 0.0914, 0.1819],
        [0.1591, 0.1994, 0.1962, 0.1477, 0.1206, 0.1769],
        [0.1610, 0.1949, 0.1923, 0.1501, 0.1265, 0.1752],
        [0.1557, 0.2092, 0.2048, 0.1419, 0.1089, 0.1794]])


## Why divide by sqrt(dimension)?

- The softmax function is sensitive to the magnitudes of its inputs. When the inputs are large, the differences between their exponentials become much more pronounced. This makes the softmax output "peaky": the highest value receives almost all of the probability mass and the rest receive very little.

See the example in the cell below.

In [20]:
### experimentation cell

tensor = torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5])

print("softmax without scaling: ", torch.softmax(tensor, dim=-1))

scaled_tensor = tensor * 8
print("softmax with scaling:", torch.softmax(scaled_tensor, dim=-1))

# After multiplying the inputs by 8, one entry gets almost all of the weight.

softmax without scaling:  tensor([0.1925, 0.1426, 0.2351, 0.1426, 0.2872])
softmax with scaling: tensor([0.0326, 0.0030, 0.1615, 0.0030, 0.8000])


As the example above shows, large scores can make the model overconfident in a single key. Such a sharp distribution can make learning unstable.
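
To see the trend rather than a single data point, we can sweep over several scale factors. This is only a sketch: the factors below are arbitrary values picked for illustration.

```python
import torch

scores = torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5])

max_probs = []
for factor in [1, 2, 4, 8, 16]:
    probs = torch.softmax(scores * factor, dim=-1)
    max_probs.append(probs.max().item())
    # The largest probability grows as the inputs get larger in magnitude
    print(f"factor={factor:>2}  max probability={probs.max().item():.4f}")
```

The larger the inputs, the closer the softmax output gets to a one-hot vector.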

### But why sqrt?

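To keep the variance of the dot product roughly constant as the dimension grows.

The dot product of a query and a key is a sum of products of their components, one product per dimension. If the components are independent random numbers with zero mean and unit variance, each product adds to the variance of the sum.

The variance therefore grows with the dimension. Dividing by sqrt(dimension) keeps the variance close to 1.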
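
This can be made precise under a simplifying assumption that is not part of the text above: the components $q_i$ and $k_i$ are all independent, with zero mean and unit variance, and $d$ is the dimension. A short derivation under that assumption:

```latex
\begin{aligned}
q \cdot k &= \sum_{i=1}^{d} q_i k_i, \\
\mathbb{E}[q_i k_i] &= \mathbb{E}[q_i]\,\mathbb{E}[k_i] = 0, \\
\operatorname{Var}(q_i k_i) &= \mathbb{E}[q_i^2]\,\mathbb{E}[k_i^2] - \left(\mathbb{E}[q_i k_i]\right)^2 = 1 \cdot 1 - 0 = 1, \\
\operatorname{Var}(q \cdot k) &= \sum_{i=1}^{d} \operatorname{Var}(q_i k_i) = d, \\
\operatorname{Var}\!\left(\frac{q \cdot k}{\sqrt{d}}\right) &= \frac{1}{d}\,\operatorname{Var}(q \cdot k) = 1.
\end{aligned}
```

Real query and key vectors are not exactly independent standard normal variables, so this is a motivation for the scaling factor, not a guarantee.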

Example:

In [23]:
import numpy as np

# function to compute variance before and after scaling
def compute_variance(dim, num_trials=1000):
  dot_products = []
  scaled_dot_products = []

  for _ in range(num_trials):
    q = np.random.randn(dim)
    k = np.random.randn(dim)

    dot_product = np.dot(q, k)
    dot_products.append(dot_product)

    scaled_dot_product = dot_product / np.sqrt(dim)
    scaled_dot_products.append(scaled_dot_product)

  variance_before_scaling = np.var(dot_products)
  variance_after_scaling = np.var(scaled_dot_products)

  return variance_before_scaling, variance_after_scaling


# for dimension 5
variance_before_5, variance_after_5 = compute_variance(5)
print(f"variance before scaling for dim=5: {variance_before_5}")
print(f"variance after scaling for dim=5: {variance_after_5}")


# for dimension 20
variance_before_20, variance_after_20 = compute_variance(20)
print(f"variance before scaling for dim=20: {variance_before_20}")
print(f"variance after scaling for dim=20: {variance_after_20}")


# for dimension 100
variance_before_100, variance_after_100 = compute_variance(100)
print(f"variance before scaling for dim=100: {variance_before_100}")
print(f"variance after scaling for dim=100: {variance_after_100}")

variance before scaling for dim=5: 5.656960453679603
variance after scaling for dim=5: 1.1313920907359205
variance before scaling for dim=20: 20.839055222643307
variance after scaling for dim=20: 1.0419527611321655
variance before scaling for dim=100: 95.87793641120774
variance after scaling for dim=100: 0.9587793641120772


Final step: compute the context vectors.

We obtained the attention weights from the query and key vectors. Multiplying these attention weights by the value vectors gives the final context vectors.



In [25]:
context_vec = attn_weights @ values
print(context_vec)

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]])
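
The matrix product above hides two facts that are worth checking: each row of the attention weights sums to 1, and each context vector is a weighted sum of the value vectors. The sketch below rebuilds the tensors with the same seed and writes out the weighted sum for token 2 explicitly. The name `context_vec_2` is introduced here for illustration.

```python
import torch

inputs = torch.tensor(
    [[0.43, 0.15, 0.89],
     [0.55, 0.87, 0.66],
     [0.57, 0.85, 0.64],
     [0.22, 0.58, 0.33],
     [0.77, 0.25, 0.10],
     [0.05, 0.80, 0.55]]
)

torch.manual_seed(123)
W_query = torch.rand(3, 2)
W_key = torch.rand(3, 2)
W_value = torch.rand(3, 2)

queries, keys, values = inputs @ W_query, inputs @ W_key, inputs @ W_value
attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores / (keys.shape[-1] ** 0.5), dim=-1)

# Each row is a probability distribution over the 6 tokens
print(attn_weights.sum(dim=-1))

# Context vector for token 2, built explicitly as a weighted sum of value vectors
context_vec_2 = torch.zeros(values.shape[1])
for weight, value in zip(attn_weights[1], values):
    context_vec_2 += weight * value

context_vec = attn_weights @ values
print(torch.allclose(context_vec_2, context_vec[1]))  # True
```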


## Implementing a compact self-attention Python class

In [28]:
import torch.nn as nn

class SelfAttention_v1(nn.Module):

  def __init__(self, d_in, d_out):
    super().__init__()
    self.W_query = nn.Parameter(torch.rand(d_in, d_out))
    self.W_key   = nn.Parameter(torch.rand(d_in, d_out))
    self.W_value = nn.Parameter(torch.rand(d_in, d_out))

  def forward(self, x):
    keys = x @ self.W_key
    queries = x @ self.W_query
    values = x @ self.W_value

    attn_scores = queries @ keys.T
    attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)

    context_vec = attn_weights @ values
    return context_vec

In this PyTorch code, SelfAttention_v1 is a class derived from nn.Module, a
fundamental building block of PyTorch models that provides the necessary functionality for
model layer creation and management.

The __init__ method initializes trainable weight matrices (W_query, W_key, and
W_value) for queries, keys, and values, each transforming the input dimension d_in to an
output dimension d_out.

During the forward pass, using the forward method, we compute the attention scores
(attn_scores) by multiplying queries and keys, then scale these scores and normalize them using softmax.

Finally, we create a context vector by weighting the values with these normalized attention
weights.

In [29]:
torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in, d_out)
print(sa_v1(inputs))

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)


We can improve the SelfAttention_v1 implementation further by utilizing PyTorch's
nn.Linear layers, which effectively perform matrix multiplication when the bias units are
disabled.

Additionally, a significant advantage of using nn.Linear instead of manually
implementing nn.Parameter(torch.rand(...)) is that nn.Linear has an optimized weight
initialization scheme, contributing to more stable and effective model training.
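
As a small sketch of the first point: with bias=False, an nn.Linear layer is just a matrix multiplication with its (transposed) weight. The seed and the random input `x` below are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # arbitrary seed, only for reproducibility
d_in, d_out = 3, 2

linear = nn.Linear(d_in, d_out, bias=False)
x = torch.rand(6, d_in)

# nn.Linear stores its weight with shape (d_out, d_in),
# so the layer computes x @ weight.T
print(linear.weight.shape)                             # torch.Size([2, 3])
print(torch.allclose(linear(x), x @ linear.weight.T))  # True
```

Note the transposed storage: the matrices in SelfAttention_v1 have shape (d_in, d_out), while nn.Linear keeps (d_out, d_in).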


In [30]:
class SelfAttention_v2(nn.Module):

    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)

        context_vec = attn_weights @ values
        return context_vec

In [31]:
torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in, d_out)
print(sa_v2(inputs))

tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)


Note that SelfAttention_v1 and SelfAttention_v2 give different outputs because they
use different initial weights for the weight matrices since nn.Linear uses a more
sophisticated weight initialization scheme.
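
To confirm that the initial weights are the only difference, we can copy the weights of a SelfAttention_v2 instance into a SelfAttention_v1 instance and compare the outputs. This is a sketch, not part of the original notebook: both classes are repeated so that the cell runs on its own, and the weights are transposed because nn.Linear stores them as (d_out, d_in).

```python
import torch
import torch.nn as nn


class SelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        keys, queries, values = x @ self.W_key, x @ self.W_query, x @ self.W_value
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        return attn_weights @ values


class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        keys, queries, values = self.W_key(x), self.W_query(x), self.W_value(x)
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        return attn_weights @ values


inputs = torch.tensor(
    [[0.43, 0.15, 0.89],
     [0.55, 0.87, 0.66],
     [0.57, 0.85, 0.64],
     [0.22, 0.58, 0.33],
     [0.77, 0.25, 0.10],
     [0.05, 0.80, 0.55]]
)

torch.manual_seed(789)
sa_v1 = SelfAttention_v1(3, 2)
sa_v2 = SelfAttention_v2(3, 2)

# Overwrite the v1 weights with the (transposed) v2 weights
with torch.no_grad():
    sa_v1.W_query.copy_(sa_v2.W_query.weight.T)
    sa_v1.W_key.copy_(sa_v2.W_key.weight.T)
    sa_v1.W_value.copy_(sa_v2.W_value.weight.T)

# With identical weights, both implementations should agree
print(torch.allclose(sa_v1(inputs), sa_v2(inputs), atol=1e-6))  # True
```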

## Why do we use the terms key, query and value?

Query = Analogous to a search query in a database. It represents the current token the model is focusing on.

Key = In the attention mechanism, each item in the input sequence has a key. Keys are matched against the query.

Value = It represents the actual content or representation of the input items. Once the model determines which keys (which parts of the input) are most relevant to the query (the current focus item), it retrieves the corresponding values.
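
The database analogy can be sketched as a "soft" lookup. Everything below (the toy keys, values, and query) is made up for illustration. A hard lookup returns the value of the single best-matching key, while attention returns a blend of all values, weighted by how well each key matches the query.

```python
import torch

# Toy "database": three keys with one value vector each (numbers are arbitrary)
keys = torch.tensor([[1.0, 0.0],
                     [0.0, 1.0],
                     [0.5, 0.5]])
values = torch.tensor([[10.0, 0.0],
                       [0.0, 10.0],
                       [5.0, 5.0]])
query = torch.tensor([2.0, 0.0])  # points in the direction of the first key

# Match the query against every key
scores = keys @ query

# Hard lookup: return the value of the best-matching key only
hard_result = values[scores.argmax()]

# Soft lookup (attention): weight every value by its normalized match score
weights = torch.softmax(scores / keys.shape[-1] ** 0.5, dim=-1)
soft_result = weights @ values

print("hard lookup:", hard_result)
print("attention weights:", weights)
print("soft lookup:", soft_result)
```

The soft result is dominated by the first value but still mixes in the others, which is the behavior we saw in the context vectors above.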