## Input sequence: "Dream big and work for it"

In [1]:
import torch

inputs = torch.tensor(
    [[0.72, 0.45, 0.31], # Dream    (x^1)
     [0.75, 0.20, 0.55], # big      (x^2)
     [0.30, 0.80, 0.40], # and      (x^3)
     [0.85, 0.35, 0.60], # work     (x^4)
     [0.55, 0.15, 0.75], # for      (x^5)
     [0.25, 0.20, 0.85]] # it       (x^6)
)

# Corresponding words
words = ['Dream', 'big', 'and', 'work', 'for', 'it']

## We want to generate the context vector for 2nd token

In [2]:
x_2 = inputs[1]
d_in = inputs.shape[1]
d_out = 2 #C = dimensionality of the context vector
print(x_2)
print(d_in)
print(d_out)

tensor([0.7500, 0.2000, 0.5500])
3
2


## Randomly initializing Wq, Wk, Wv matrices

In [3]:
torch.manual_seed(123)
W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

In [4]:
print(W_query)

Parameter containing:
tensor([[0.2961, 0.5166],
        [0.2517, 0.6886],
        [0.0740, 0.8665]])


In [5]:
print(W_key)

Parameter containing:
tensor([[0.1366, 0.1025],
        [0.1841, 0.7264],
        [0.3153, 0.6871]])


In [6]:
print(W_value)

Parameter containing:
tensor([[0.0756, 0.1966],
        [0.3164, 0.4017],
        [0.1186, 0.8274]])


In [7]:
query_2 = x_2 @ W_query
key_2 = x_2 @ W_key
value_2 = x_2 @ W_value
print(query_2)
print(key_2)
print(value_2)

tensor([0.3131, 1.0017])
tensor([0.3126, 0.6001])
tensor([0.1852, 0.6829])


## Calculating Q, K, and V using X, Wq, Wk, Wv

In [8]:
keys = inputs @ W_key
values = inputs @ W_value
queries = inputs @ W_query

print('keys.shape', keys.shape)
print('values.shape', values.shape)
print('queries.shape', queries.shape)

print('keys:', keys)
print('queries', queries)
print('values', values)

keys.shape torch.Size([6, 2])
values.shape torch.Size([6, 2])
queries.shape torch.Size([6, 2])
keys: tensor([[0.2789, 0.6137],
        [0.3126, 0.6001],
        [0.3143, 0.8867],
        [0.3697, 0.7536],
        [0.3392, 0.6807],
        [0.3389, 0.7549]])
queries tensor([[0.3494, 0.9504],
        [0.3131, 1.0017],
        [0.3198, 1.0524],
        [0.3842, 1.2000],
        [0.2561, 1.0373],
        [0.1872, 1.0034]])
values tensor([[0.2336, 0.5789],
        [0.1852, 0.6829],
        [0.3232, 0.7113],
        [0.2462, 0.8042],
        [0.1780, 0.7890],
        [0.1830, 0.8328]])


## Keys corresponding to second token and the attention of second token to itself

In [9]:
keys_2 = keys[1]
attn_score_22 = query_2 @ keys_2
print(attn_score_22)

tensor(0.6990)


## All attention scores for query number 2

In [10]:
attn_scores_2 = query_2 @ keys.T # All attention scores for given query
print(attn_scores_2)

tensor([0.7021, 0.6990, 0.9867, 0.8707, 0.7880, 0.8624])


## Attention scores (NOT Weights) matrix

In [11]:
attn_scores = queries @ keys.T # omega
print(attn_scores)

tensor([[0.6807, 0.6795, 0.9526, 0.8454, 0.7654, 0.8359],
        [0.7021, 0.6990, 0.9867, 0.8707, 0.7880, 0.8624],
        [0.7350, 0.7315, 1.0337, 0.9113, 0.8248, 0.9029],
        [0.8436, 0.8402, 1.1848, 1.0464, 0.9471, 1.0361],
        [0.7080, 0.7025, 1.0003, 0.8764, 0.7929, 0.8699],
        [0.6680, 0.6606, 0.9486, 0.8254, 0.7465, 0.8210]])


## Scale by 1/sqrt(d) and then take softmax

In [12]:
attn_scores_2 = query_2 @ keys.T # All attention scores for given query
print(attn_scores_2)
d_k = keys.shape[-1]
attn_weights_2 = torch.softmax(attn_scores_2 / d_k**0.5, dim=-1)
print(attn_weights_2)

attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores / d_k**0.5, dim=-1)
print(attn_weights)

tensor([0.7021, 0.6990, 0.9867, 0.8707, 0.7880, 0.8624])
tensor([0.1531, 0.1528, 0.1873, 0.1725, 0.1627, 0.1715])
tensor([[0.1536, 0.1534, 0.1861, 0.1725, 0.1630, 0.1714],
        [0.1531, 0.1528, 0.1873, 0.1725, 0.1627, 0.1715],
        [0.1525, 0.1521, 0.1884, 0.1728, 0.1625, 0.1717],
        [0.1505, 0.1501, 0.1915, 0.1737, 0.1619, 0.1724],
        [0.1530, 0.1524, 0.1881, 0.1724, 0.1625, 0.1716],
        [0.1538, 0.1530, 0.1875, 0.1719, 0.1625, 0.1713]])


## Softmax peaks when the numbers are scaled

In [13]:
import torch
# Define the tensor
tensor = torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5])

# Apply softmax
# Apply softmax without scaling
softmax_result = torch.softmax(tensor, dim=-1)
print("Softmax without scaling:", softmax_result)

# Multiply the tensor by 8 and the apply softmax
scaled_tensor = tensor * 8
softmax_scaled_result = torch.softmax(scaled_tensor, dim=-1)
print("Softmax after scaling (tensor * 8):", softmax_scaled_result)

Softmax without scaling: tensor([0.1925, 0.1426, 0.2351, 0.1426, 0.2872])
Softmax after scaling (tensor * 8): tensor([0.0326, 0.0030, 0.1615, 0.0030, 0.8000])


## Scaling has to be such that the variance of Q*K.T is close to 1

In [14]:
import numpy as np

# Function to compute variance before and after scaling
def compute_variance(dim, num_trials=1000):
    dot_products = []
    scaled_dot_products = []

    # Generate multiple random vectors and compute dot products
    for _ in range(num_trials):
        q = np.random.randn(dim)
        k = np.random.randn(dim)

        # Compute dot product
        dot_product = q @ k
        dot_products.append(dot_product)

        # Scale the dot product by sqrt (dim)
        scaled_dot_product = dot_product / (dim)
        scaled_dot_products.append(scaled_dot_product)
          
    # Calculate the variance of the dot products
    variance_before_scaling = np.var(dot_products)
    variance_after_scaling = np.var(scaled_dot_products)

    return variance_before_scaling, variance_after_scaling

# For dimension 5
variance_before_5, variance_after_5 = compute_variance(5)
print(f"Variance before scaling (dim=5): {variance_before_5}")
print(f"Variance after scaling (dim=5): {variance_after_5}")

# For dimension 100
variance_before_100, variance_after_100 = compute_variance(100)
print(f"Variance before scaling (dim=100): {variance_before_100}")
print(f"Variance after scaling (dim=100): {variance_after_100}")

Variance before scaling (dim=5): 4.862734596201852
Variance after scaling (dim=5): 0.19450938384807406
Variance before scaling (dim=100): 97.49757474284462
Variance after scaling (dim=100): 0.009749757474284462


## Context vector corresponding to 2nd input token

In [15]:
context_vec_2 = attn_weights_2 @ values
context_vec = attn_weights @ values
print(context_vec_2)
print(context_vec)

tensor([0.2274, 0.7362])
tensor([[0.2273, 0.7361],
        [0.2274, 0.7362],
        [0.2276, 0.7363],
        [0.2280, 0.7368],
        [0.2275, 0.7362],
        [0.2275, 0.7360]])


## Python class for doing this whole operation

In [16]:
import torch.nn as nn

class SelfAttention_v1(nn.Module):

    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value

        attn_scores = queries @ keys.T # omega
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)

        context_vec = attn_weights @ values
        return context_vec

In [17]:
torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in, d_out)
print(sa_v1(inputs))

tensor([[0.2273, 0.7361],
        [0.2274, 0.7362],
        [0.2276, 0.7363],
        [0.2280, 0.7368],
        [0.2275, 0.7362],
        [0.2275, 0.7360]], grad_fn=<MmBackward0>)


In [22]:
class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        keys = self.W_key(x)  
        queries = self.W_query(x) 
        values = self.W_value(x)  

        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / (keys.shape[-1] ** 0.5), dim=-1)
        context_vec = torch.matmul(attn_weights, values)  # [batch, seq_len, d_out]
        return context_vec

# Example test
torch.manual_seed(123)
sa_v2 = SelfAttention_v1(d_in, d_out)
print(sa_v2(inputs))

tensor([[0.2273, 0.7361],
        [0.2274, 0.7362],
        [0.2276, 0.7363],
        [0.2280, 0.7368],
        [0.2275, 0.7362],
        [0.2275, 0.7360]], grad_fn=<MmBackward0>)
