# **Implementing Self-attention with trainable weights**

In [1]:
import torch

In [2]:
# Tokens of the example sentence, in the same order as the rows of `inputs`.
words = ['Your', 'journey', 'starts', 'with', 'one', 'step']

# Toy 3-dimensional embedding per token: one row per token, shape (6, 3).
inputs = torch.tensor(
  [
   [0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55], # step     (x^6)
  ]
)

In [3]:
inputs.shape

torch.Size([6, 3])

In [4]:
x_2 = inputs[1] # second input row (the "journey" embedding), used as the worked example
d_in = inputs.shape[1] # input embedding size (3 for this `inputs` tensor)
d_out = 2 # output size of the query/key/value projections

In [5]:
torch.manual_seed(123)
# Each projection matrix has shape (d_in, d_out) = (3, 2). The number of rows
# must match the dimension of the input vectors (3 in our case); the number of
# columns (d_out) can be chosen freely.
# requires_grad=False: no gradients are tracked for these matrices in this walkthrough.
W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False) # 3x2
W_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

In [6]:
W_query

Parameter containing:
tensor([[0.2961, 0.5166],
        [0.2517, 0.6886],
        [0.0740, 0.8665]])

In [7]:
W_key

Parameter containing:
tensor([[0.1366, 0.1025],
        [0.1841, 0.7264],
        [0.3153, 0.6871]])

In [8]:
W_value

Parameter containing:
tensor([[0.0756, 0.1966],
        [0.3164, 0.4017],
        [0.1186, 0.8274]])

In [9]:
# Project only the second input vector x_2 into its query, key and value vectors.
query_2 = x_2 @ W_query # x_2 is 1-D: (3,) @ (3, 2) -> (2,)
key_2 = x_2 @ W_key
value_2 = x_2 @ W_value
query_2, key_2, value_2

(tensor([0.4306, 1.4551]), tensor([0.4433, 1.1419]), tensor([0.3951, 1.0037]))

In [10]:
# Project every input row at once to get the full query, key and value matrices.
queries = inputs @ W_query # 6x3 @ 3x2 = 6x2
keys = inputs @ W_key # 6x3 @ 3x2 = 6x2
values = inputs @ W_value # 6x3 @ 3x2 = 6x2

queries.shape, keys.shape, values.shape

(torch.Size([6, 2]), torch.Size([6, 2]), torch.Size([6, 2]))

In [11]:
# Attention score omega_22: dot product of the query for x_2 with the key for x_2.
# keys[1] is row 1 of `inputs @ W_key`, i.e. the same vector as `x_2 @ W_key`.
key_2 = keys[1]
attn_score_22 = query_2.dot(key_2) 
attn_score_22

tensor(1.8524)

In [12]:
# Scores of query 2 against all six keys; entry 1 equals attn_score_22.
attn_scores_2 = query_2 @ keys.T # 1X2 @ 2X6
attn_scores_2

tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])

In [13]:
# All pairwise scores: row i holds the scores of query i against every key (6x6).
attn_scores = queries @ keys.T # 6x2 @ 2x6 -> omega
attn_scores

tensor([[0.9231, 1.3545, 1.3241, 0.7910, 0.4032, 1.1330],
        [1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440],
        [1.2544, 1.8284, 1.7877, 1.0654, 0.5508, 1.5238],
        [0.6973, 1.0167, 0.9941, 0.5925, 0.3061, 0.8475],
        [0.6114, 0.8819, 0.8626, 0.5121, 0.2707, 0.7307],
        [0.8995, 1.3165, 1.2871, 0.7682, 0.3937, 1.0996]])

We compute the attention weights by scaling the attention scores and applying the softmax function. The scores are scaled by dividing them by the square root of the embedding dimension of the keys, $\sqrt{d_k}$.

**Note that taking the square root is mathematically the same as raising to the power 0.5.**

In [14]:
d_k = keys.shape[-1] # size of each key vector (last dimension of `keys`)
# Divide the scores by sqrt(d_k), then softmax over the six keys to get weights.
attn_weights_2 = torch.softmax(attn_scores_2 / d_k**0.5, dim=-1)
attn_weights_2, d_k

(tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820]), 2)

## **WHY DIVIDE BY SQRT (DIMENSION)**

<div class="alert alert-block alert-warning">

Reason 1: For stability in learning

The softmax function is sensitive to the magnitudes of its inputs. When the inputs are large, the differences between the exponential values of each input become much more pronounced. This causes the softmax output to become "peaky," where the highest value receives almost all the probability mass, and the rest receive very little.

In attention mechanisms, particularly in transformers, if the dot products between query and key vectors become too large (simulated by multiplying the inputs by 8 in the example below), the attention scores can become very large. This results in a very sharp softmax distribution, making the model overly confident in one particular "key." Such sharp distributions can make learning unstable.
    
</div>

In [15]:
# define a small tensor of toy scores to feed through softmax
tensor  = torch.tensor([0.1, -0.2,0.3,-0.2,0.5])

# apply softmax without scaling
softmax_result = torch.softmax(tensor, dim=-1)
print(f"Softmax without scaling: {softmax_result}")

# multiply the tensor by 8 to mimic large-magnitude scores, then apply softmax;
# the same ordering of values now yields a much more peaked distribution
scaled_tensor = tensor * 8
softmax_scaled_result = torch.softmax(scaled_tensor, dim=-1)
print(f"Softmax after scaling (tensor * 8): {softmax_scaled_result}")


Softmax without scaling: tensor([0.1925, 0.1426, 0.2351, 0.1426, 0.2872])
Softmax after scaling (tensor * 8): tensor([0.0326, 0.0030, 0.1615, 0.0030, 0.8000])


<div class="alert alert-block alert-warning">

Reason 2: To make the variance of the dot product stable

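The dot product of Q and K is a sum of `dimension` element-wise products. If the components of the query and key vectors are independent with mean 0 and variance 1, each product has variance 1, so the variance of the whole dot product equals the dimension.

The variance therefore grows linearly with the dimension.

Dividing by sqrt (dimension) divides the variance by the dimension, which keeps it close to 1 regardless of the dimension.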
    
</div>

In [16]:
import numpy as np

# Function to compute variance before and after scaling
def compute_variance(dim, num_trials=1000, rng=None):
    """Estimate the variance of q.k before and after dividing by sqrt(dim).

    Each trial draws two independent standard-normal vectors of length
    ``dim`` and takes their dot product.

    Args:
        dim: Length of the random vectors.
        num_trials: Number of (q, k) pairs to sample.
        rng: Optional ``numpy.random.Generator`` for reproducible runs.
            When omitted, the global NumPy random state is used (so
            ``np.random.seed`` still controls the result).

    Returns:
        Tuple ``(variance_before_scaling, variance_after_scaling)``.
    """
    draw = np.random.randn if rng is None else rng.standard_normal

    # Arguments evaluate left to right, so q is drawn before k in every
    # trial; this keeps the sequence of random draws unchanged for a
    # given seed.
    dot_products = np.array(
        [np.dot(draw(dim), draw(dim)) for _ in range(num_trials)]
    )

    # Scale all dot products by sqrt(dim) in one array operation
    scaled_dot_products = dot_products / np.sqrt(dim)

    return np.var(dot_products), np.var(scaled_dot_products)

# For dimension 5
variance_before_5, variance_after_5 = compute_variance(5)
print(f"Variance before scaling (dim=5): {variance_before_5}")
print(f"Variance after scaling (dim=5): {variance_after_5}")

# For dimension 100
variance_before_100, variance_after_100 = compute_variance(100)
print(f"Variance before scaling (dim=100): {variance_before_100}")
print(f"Variance after scaling (dim=100): {variance_after_100}")



Variance before scaling (dim=5): 4.692901219967951
Variance after scaling (dim=5): 0.9385802439935901
Variance before scaling (dim=100): 98.83632170475772
Variance after scaling (dim=100): 0.9883632170475772


In [17]:
# Context vector for x_2: the attention-weighted sum of all value vectors.
context_vec_2 = attn_weights_2 @ values
context_vec_2

tensor([0.3061, 0.8210])

In [None]:
import torch.nn as nn

class SelfAttention_v1(nn.Module):
    """Single-head scaled dot-product self-attention using raw weight matrices.

    The query, key and value projections are plain ``nn.Parameter`` matrices
    of shape ``(d_in, d_out)`` initialised with ``torch.rand``.

    Args:
        d_in: Size of each input embedding vector.
        d_out: Size of each query/key/value (and output) vector.
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        # Created in query, key, value order; the order determines which
        # random draws each matrix receives for a given seed.
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key   = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        """Return the context vectors for ``x``.

        Args:
            x: Tensor of shape ``(seq_len, d_in)`` or, with leading batch
                dimensions, ``(..., seq_len, d_in)``.

        Returns:
            Tensor of shape ``(..., seq_len, d_out)``.
        """
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value
        # S = QKᵀ / √dₖ
        # transpose(-2, -1) swaps only the last two axes: identical to `.T`
        # for 2-D input, and also correct when batch dimensions are present
        # (where `.T` would reverse every axis).
        attn_scores = queries @ keys.transpose(-2, -1)
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )

        context_vec = attn_weights @ values
        return context_vec

In [None]:
# Seed 123 again: the class draws W_query, W_key and W_value with torch.rand in
# the same order as the standalone matrices created earlier with this seed, so
# row 2 of the output reproduces context_vec_2.
torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in=d_in, d_out=d_out)
sa_v1.forward(inputs) # or sa_v1(inputs)

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)

### Why use `nn.Linear` instead of `nn.Parameter` in self-attention?

- **`nn.Parameter`**: Just a tensor marked as learnable.  
  - Gives full manual control (init, bias, forward math).  
  - But you must code all matrix multiplications yourself.  

- **`nn.Linear`**: A higher-level layer.  
  - Wraps weights (`nn.Parameter`) + optional bias.  
  - Provides the forward pass `x @ W^T + b` automatically.  
  - Handles parameter initialization and shape checks.  
  - Uses optimized backend kernels for speed.  

✅ In practice (e.g., Transformers/LLMs), we prefer `nn.Linear` for query/key/value projections because it is cleaner, less error-prone, and leverages PyTorch’s optimizations.  


In [None]:
class SelfAttention_v2(nn.Module):
    """Single-head scaled dot-product self-attention built on ``nn.Linear``.

    Args:
        d_in: Size of each input embedding vector.
        d_out: Size of each query/key/value (and output) vector.
        qkv_bias: Whether the query, key and value projections get a bias
            term (passed to ``nn.Linear`` as ``bias``).
    """

    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        """Return the context vectors for ``x``.

        Args:
            x: Tensor of shape ``(seq_len, d_in)`` or, with leading batch
                dimensions, ``(..., seq_len, d_in)``.

        Returns:
            Tensor of shape ``(..., seq_len, d_out)``.
        """
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        # S = QKᵀ / √dₖ
        # transpose(-2, -1) swaps only the last two axes: identical to `.T`
        # for 2-D input, and also correct when batch dimensions are present
        # (where `.T` would reverse every axis).
        attn_scores = queries @ keys.transpose(-2, -1)
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)

        context_vec = attn_weights @ values
        return context_vec

In [24]:
# A different seed is used here, and the weights come from nn.Linear rather
# than torch.rand, so the values are not comparable with the SelfAttention_v1 output.
torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in, d_out)
sa_v2(inputs)

tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)