## Multi-Head Attention

### Why scaling is important before softmax

* We divide the attention scores by sqrt(key dimension) so that their variance does not grow with the dimension.

In [2]:
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
# import random

In the attention mechanism, when the dot products between query and key vectors become too large (as in the example below, where the scores are multiplied by 8), the attention scores become very large. This results in a very sharp softmax distribution, making the model overly confident in one particular "key". Such sharp distributions can make learning unstable.


In [3]:
# define the tensor
tensor = torch.tensor([0.1, -0.2, 0.3, -0.4, -0.3])

# apply softmax without scaling
softmax_result = torch.softmax(tensor, dim =-1)
print("Softmax without scaling:", softmax_result)

# multiply the scores by 8 to mimic large dot products, then apply softmax
scaled_tensor = tensor*8
softmax_scaled_result = torch.softmax(scaled_tensor, dim = -1)
print("Softmax with scaling:", softmax_scaled_result)

Softmax without scaling: tensor([0.2359, 0.1748, 0.2881, 0.1431, 0.1581])
Softmax with scaling: tensor([0.1639, 0.0149, 0.8116, 0.0030, 0.0067])


* The scaled-up scores produce a sharply peaked distribution: one entry receives about 81% of the weight and the others are nearly ignored. In attention, this means the model attends almost exclusively to a single key.
* **Thus, we need to scale the attention scores down before applying softmax, so that the resulting weights (which always add up to 1) are spread more evenly.**
* **Scaling** is done by dividing the attention scores, i.e. the dot product of the query matrix and the transposed key matrix, by the square root of the key dimension.
* Softmax then converts the scaled **attention scores** into **attention weights**.
* The goal is to keep the variance of the dot product stable, whatever the dimension.

Dividing by sqrt(dimension) keeps the variance of the scores roughly constant as the dimension grows, which makes learning more stable.
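As a minimal sketch of this effect, the snippet below compares softmax weights computed from raw dot products with weights computed after dividing by sqrt(d_k). The key dimension `d_k = 64`, the random query and keys, and the seed are illustrative assumptions, not values from this notebook.

```python
import torch

torch.manual_seed(0)

d_k = 64                      # hypothetical key dimension (assumption)
query = torch.randn(d_k)      # one random query vector
keys = torch.randn(5, d_k)    # five random key vectors

# raw dot products; their spread grows with d_k
scores = keys @ query

# softmax on the raw scores vs. on the scores divided by sqrt(d_k)
weights_raw = torch.softmax(scores, dim=-1)
weights_scaled = torch.softmax(scores / d_k ** 0.5, dim=-1)

print("Without dividing by sqrt(d_k):", weights_raw)
print("After dividing by sqrt(d_k):  ", weights_scaled)
```

Both weight vectors sum to 1, but the largest weight after scaling is never larger than the largest weight before scaling, so the distribution is flatter.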


In [4]:
# Function to compute variance before and after scaling

def compute_variance(dim, trials = 1000):
    dot_products = []
    scaled_dot_products = []

    # Generate multiple random vectors and compute dot products
    for _ in range(trials):
        q = np.random.randn(dim)
        k = np.random.randn(dim)

        # getting dot product
        dot_product = np.dot(q,k)
        dot_products.append(dot_product)

        # scale the dot product by sqrt
        scaled_dot_product = dot_product/np.sqrt(dim)
        scaled_dot_products.append(scaled_dot_product)
    
    variance_before_scaling = np.var(dot_products)
    variance_after_scaling = np.var(scaled_dot_products)

    return variance_before_scaling, variance_after_scaling

variance_before_5, variance_after_5 = compute_variance(5)
print(f"Variance before scaling (dim=5): {variance_before_5}")
print(f"variance after scaling (dim=5): {variance_after_5}")

variance_before_100, variance_after_100 = compute_variance(100)
print(f"Variance before scaling (dim=100): {variance_before_100}")
print(f"variance after scaling (dim=100): {variance_after_100}")

Variance before scaling (dim=5): 4.971412062399333
variance after scaling (dim=5): 0.9942824124798666
Variance before scaling (dim=100): 97.12069181019781
variance after scaling (dim=100): 0.9712069181019781
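
These numbers agree with a short calculation. As a sketch, assume the components $q_i$ and $k_i$ are independent with zero mean and unit variance (this holds for `np.random.randn`, but is only an approximation for trained queries and keys). Then each product $q_i k_i$ has zero mean and variance $\mathbb{E}[q_i^2]\,\mathbb{E}[k_i^2] = 1$, and the $d$ products are independent, so:

$$
\mathrm{Var}(q \cdot k) = \mathrm{Var}\left(\sum_{i=1}^{d} q_i k_i\right) = \sum_{i=1}^{d} \mathrm{Var}(q_i k_i) = d
$$

$$
\mathrm{Var}\left(\frac{q \cdot k}{\sqrt{d}}\right) = \frac{1}{d}\,\mathrm{Var}(q \cdot k) = 1
$$

This is why the variance before scaling is close to 5 and 100 for dim=5 and dim=100, and close to 1 after scaling in both cases.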


#### Implementing a compact self-attention Python class

In [6]:
inputs = torch.tensor(
    [[0.43, 0.15, 0.89],
     [0.55, 0.87, 0.66],
     [0.57, 0.85, 0.64],
     [0.22, 0.58, 0.33],
     [0.77, 0.25, 0.10],
     [0.05, 0.80, 0.55]]
)

d_in = 3
d_out = 2

In [7]:
class SelfAttention_v1(nn.Module):

    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

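    def forward(self, x):
        # project the inputs into query, key and value spaces
        queries = x @ self.W_query
        keys = x @ self.W_key
        values = x @ self.W_value

        # pairwise query-key dot products, shape (num_tokens, num_tokens)
        attn_scores = queries @ keys.T
        # scale by sqrt(key dimension) before softmax; each row sums to 1
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1] ** 0.5, dim=-1
        )

        # weighted sum of the value vectors
        context_vec = attn_weights @ values
        return context_vec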
    
torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in, d_out)
print(sa_v1(inputs))

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)


Since the input contains 6 vectors, we get a matrix storing the 6 context vectors, one per row. To improve the SelfAttention_v1 implementation, we can use PyTorch's nn.Linear layers, which perform the same matrix multiplication when the bias units are disabled.

In addition, nn.Linear has an optimised weight initialisation scheme, contributing to more stable and effective model training.
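
The equivalence between a bias-free nn.Linear layer and a plain matrix multiplication can be checked with a short sketch. The random input `x` and the seed are illustrative assumptions; note that nn.Linear stores its weight with shape (d_out, d_in), so the comparison uses the transposed weight.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

d_in, d_out = 3, 2
x = torch.rand(6, d_in)       # hypothetical input: 6 tokens, 3 features each

linear = nn.Linear(d_in, d_out, bias=False)

out_linear = linear(x)              # what the layer computes
out_matmul = x @ linear.weight.T    # explicit matrix multiplication

# True if both results agree up to floating-point tolerance
print(torch.allclose(out_linear, out_matmul, atol=1e-6))
```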

#### Adding a self-attention layer and looking at the flow of the matrix multiplication between query, key and value

In [None]:
class SelfAttention_v2(nn.Module):
    
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias = qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias = qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias = qkv_bias)

    def forward(self, x):
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x) 

        attn_scores = queries @ keys.T
        attn_weight = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim = -1)

        context_vec = attn_weight @ values
        return context_vec
        
sa_v2 = SelfAttention_v2(d_in, d_out)

print(sa_v2(inputs))


tensor([[-0.1370,  0.2535],
        [-0.1289,  0.2542],
        [-0.1293,  0.2541],
        [-0.1278,  0.2542],
        [-0.1365,  0.2535],
        [-0.1240,  0.2545]], grad_fn=<MmBackward0>)


SelfAttention_v1 and SelfAttention_v2 give different outputs because they start from different initial weights: SelfAttention_v1 draws its weights with torch.rand (uniform on [0, 1)), whereas nn.Linear uses its own, more sophisticated weight initialization scheme.
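
One way to confirm that the initial weights are the only difference is to copy the weights of a SelfAttention_v2 instance into a SelfAttention_v1 instance and compare the outputs. The sketch below redefines both classes so that it runs on its own; the random input and the seed are assumptions made for illustration.

```python
import torch
import torch.nn as nn


class SelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        queries, keys, values = x @ self.W_query, x @ self.W_key, x @ self.W_value
        attn_weights = torch.softmax(queries @ keys.T / keys.shape[-1] ** 0.5, dim=-1)
        return attn_weights @ values


class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        queries, keys, values = self.W_query(x), self.W_key(x), self.W_value(x)
        attn_weights = torch.softmax(queries @ keys.T / keys.shape[-1] ** 0.5, dim=-1)
        return attn_weights @ values


torch.manual_seed(0)
x = torch.rand(6, 3)          # hypothetical input: 6 tokens, 3 features each
sa_v1 = SelfAttention_v1(3, 2)
sa_v2 = SelfAttention_v2(3, 2)

# nn.Linear stores weights as (d_out, d_in), so transpose before copying
with torch.no_grad():
    sa_v1.W_query.copy_(sa_v2.W_query.weight.T)
    sa_v1.W_key.copy_(sa_v2.W_key.weight.T)
    sa_v1.W_value.copy_(sa_v2.W_value.weight.T)

same = torch.allclose(sa_v1(x), sa_v2(x), atol=1e-6)
print("Outputs match after copying weights:", same)
```

With identical weights the two classes compute the same context vectors, which suggests the difference seen above comes from initialization rather than from the attention computation itself.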