In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Set seed for reproducibility
torch.manual_seed(123)

# SIMULATED INPUT
# Imagine we have a sentence with 3 words.
# Each word is embedded into a vector of size 3 (kept tiny for readability).
# Shape: [Sequence_Length, Embedding_Dim] = [3, 3]
inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Word 1 (e.g., "The")
   [0.55, 0.87, 0.66], # Word 2 (e.g., "cat")
   [0.57, 0.85, 0.64]] # Word 3 (e.g., "sat")
)

# Dimensions
d_in = inputs.shape[1] # 3: size of each input embedding
d_out = 2              # size each embedding is projected down to by the attention layers

In [2]:
# 1. Select the 2nd token ("cat", index 1) as our "Query"
query = inputs[1]

# 2. Score the query against every input token (including itself).
# Pre-allocate an uninitialized 1-D tensor with one slot per token;
# every slot is overwritten in the loop below.
attn_scores_2 = torch.empty(inputs.shape[0])

for i, x_i in enumerate(inputs):
    # The dot product measures similarity: query . inputs[i]
    attn_scores_2[i] = torch.dot(query, x_i)

# The loop above is just a matrix-vector product; check they agree.
assert torch.allclose(attn_scores_2, inputs @ query)

print("Raw Similarity Scores:", attn_scores_2)

Raw Similarity Scores: tensor([0.9544, 1.4950, 1.4754])


In [3]:
# Softmax maps the raw scores to non-negative weights that sum to 1,
# so each weight reads as "share of attention given to that word".
attn_weights_2 = attn_scores_2.softmax(dim=0)

print("Attention Weights:", attn_weights_2)
print("Sum check:", attn_weights_2.sum())

Attention Weights: tensor([0.2272, 0.3902, 0.3826])
Sum check: tensor(1.)


In [4]:
# Build the context vector as an attention-weighted sum of all inputs:
# start from zeros and add each input row scaled by its weight.
context_vec_2 = torch.zeros_like(query)
for weight, x_i in zip(attn_weights_2, inputs):
    context_vec_2 += weight * x_i

print("Original 'cat' vector:", query)
print("New 'Context' vector: ", context_vec_2)

Original 'cat' vector: tensor([0.5500, 0.8700, 0.6600])
New 'Context' vector:  tensor([0.5304, 0.6987, 0.7046])


In [5]:
# All pairwise similarities at once: entry [i, j] = inputs[i] . inputs[j]
attn_scores = torch.matmul(inputs, inputs.transpose(0, 1))

In [6]:
# Row-wise softmax: each query row becomes a distribution over all tokens
# (the last dimension indexes the tokens being attended to).
attn_weights = attn_scores.softmax(dim=-1)

print("Full Attention Weights:\n", attn_weights)
print("Row 1 (Cat) sums to:", attn_weights[1].sum())

Full Attention Weights:
 tensor([[0.3448, 0.3296, 0.3256],
        [0.2272, 0.3902, 0.3826],
        [0.2284, 0.3893, 0.3822]])
Row 1 (Cat) sums to: tensor(1.)


In [7]:
# One matrix multiply yields every context vector:
# row i of the result is sum_j attn_weights[i, j] * inputs[j].
all_context_vecs = torch.matmul(attn_weights, inputs)

print("All Context Vectors shape:", all_context_vecs.shape)
print("Context Vector for 'cat':\n", all_context_vecs[1])

All Context Vectors shape: torch.Size([3, 3])
Context Vector for 'cat':
 tensor([0.5304, 0.6987, 0.7046])


In [8]:
class SelfAttention_v1(nn.Module):
    """Single-head scaled dot-product self-attention (no masking).

    Every token attends to every token in the sequence.

    Args:
        d_in: size of each input embedding.
        d_out: size of the query/key/value projections and of the output.
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        self.d_out = d_out

        # Define the 3 projection matrices (W_q, W_k, W_v).
        # nn.Linear with bias=False is a plain learned matrix multiply and
        # handles initialization for us. (Creation order matters for seeded runs.)
        self.W_query = nn.Linear(d_in, d_out, bias=False)
        self.W_key   = nn.Linear(d_in, d_out, bias=False)
        self.W_value = nn.Linear(d_in, d_out, bias=False)

    def forward(self, x):
        """Apply self-attention.

        Args:
            x: tensor of shape [..., Seq_Len, d_in]; a leading batch
               dimension is optional.

        Returns:
            Tensor of shape [..., Seq_Len, d_out].
        """
        # 1. Calculate Q, K, V -> each [..., Seq_Len, d_out]
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        # 2. Attention scores = Queries @ Keys^T.
        # Swap only the last two dims (Seq_Len, d_out) so this works with
        # or without a batch dimension.
        attn_scores = queries @ keys.transpose(-2, -1)

        # 3. Scale by sqrt(d_out) to keep softmax inputs well-conditioned,
        # then normalize each row into attention weights.
        attn_weights = torch.softmax(attn_scores / self.d_out**0.5, dim=-1)

        # 4. Context vectors = weighted sum of the values
        context_vec = attn_weights @ values

        return context_vec

In [9]:
# Initialize the layer
d_in = inputs.shape[1]  # Our input embedding size (3)
d_out = 2               # Let's project them down to size 2

torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in, d_out)

# Add a "Batch Dimension" so the input has the [Batch, Seq_Len, d_in] layout
# the attention layer is written for. (nn.Linear itself does NOT require it:
# it accepts any number of leading dimensions and only transforms the last one.)
# inputs shape: [3, 3] -> [1, 3, 3] (1 batch, 3 tokens, 3 dim)
batch_input = inputs.unsqueeze(0)

print("Input shape:", batch_input.shape)

context_vecs = sa_v1(batch_input)
print("\nContext Vectors shape:", context_vecs.shape)
print("Context Vectors:\n", context_vecs)

Input shape: torch.Size([1, 3, 3])

Context Vectors shape: torch.Size([1, 3, 2])
Context Vectors:
 tensor([[[-0.6278, -0.0596],
         [-0.6301, -0.0633],
         [-0.6300, -0.0632]]], grad_fn=<UnsafeViewBackward0>)


In [10]:
class CausalAttention(nn.Module):
    """Single-head scaled dot-product self-attention with a causal mask.

    Token i may only attend to tokens 0..i, so no position sees later ones.

    Args:
        d_in: size of each input embedding.
        d_out: size of the query/key/value projections and of the output.
        context_length: longest sequence the precomputed mask supports.
    """

    def __init__(self, d_in, d_out, context_length=1024):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=False)
        self.W_key   = nn.Linear(d_in, d_out, bias=False)
        self.W_value = nn.Linear(d_in, d_out, bias=False)

        # Create the mask once and store it: mask[i, j] == 1 iff j <= i.
        # "register_buffer" tells PyTorch: "This is part of the model state,
        # but it's not a trainable weight (don't update it with gradients)."
        self.register_buffer(
            'mask',
            torch.tril(torch.ones(context_length, context_length))
        )

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: tensor of shape [..., num_tokens, d_in]; a leading batch
               dimension is optional. num_tokens must not exceed
               context_length (the mask slice would be too small).

        Returns:
            Tensor of shape [..., num_tokens, d_out].
        """
        num_tokens = x.shape[-2]

        # 1. Calculate Q, K, V
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        # 2. Calculate Unmasked Scores (swap only the last two dims so a
        # batch dimension is optional)
        attn_scores = queries @ keys.transpose(-2, -1)

        # 3. APPLY MASK
        # Slice the mask to match the current sequence length (e.g., 3x3)
        mask_slice = self.mask[:num_tokens, :num_tokens]

        # Wherever the mask is 0 (a future position), set score to -infinity
        attn_scores.masked_fill_(mask_slice == 0, -float('inf'))

        # 4. Softmax (the -inf entries become exactly 0.0 here)
        attn_weights = torch.softmax(attn_scores / self.d_out**0.5, dim=-1)

        # 5. Output: weighted sum of the (visible) values
        context_vec = attn_weights @ values
        return context_vec

In [11]:
torch.manual_seed(123)
causal_attn = CausalAttention(d_in=3, d_out=2)

# Pass the input
batch_input = inputs.unsqueeze(0) # Ensure it's [1, 3, 3]
output = causal_attn(batch_input)

print("Output shape:", output.shape)

# Recompute the attention by hand to verify the layer
with torch.no_grad():
    num_tokens = batch_input.shape[1]
    q = causal_attn.W_query(batch_input)
    k = causal_attn.W_key(batch_input)
    v = causal_attn.W_value(batch_input)
    scores = q @ k.transpose(1, 2)

    # Apply the causal mask manually: hide every position j > i
    mask = torch.tril(torch.ones(num_tokens, num_tokens))
    scores.masked_fill_(mask == 0, -float('inf'))
    # Scale by sqrt(d_out) exactly as the layer does (not a hard-coded 2)
    weights = torch.softmax(scores / causal_attn.d_out**0.5, dim=-1)

    # The hand-computed attention must reproduce the layer's output
    assert torch.allclose(weights @ v, output), "manual attention does not match CausalAttention"

    print("\nAttention Weights (Row 0 = Word 1):\n", weights[0, 0])

Output shape: torch.Size([1, 3, 2])

Attention Weights (Row 0 = Word 1):
 tensor([1., 0., 0.])


In [12]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention.

    The d_out-sized Q/K/V projections are split into `num_heads` heads of
    size d_out // num_heads. Each head runs causal scaled dot-product
    attention independently; the head outputs are concatenated and mixed
    by a final linear layer.

    Args:
        d_in: size of each input embedding.
        d_out: total output size (must be divisible by num_heads).
        context_length: longest sequence the precomputed causal mask supports.
        num_heads: number of attention heads.
    """

    def __init__(self, d_in, d_out, context_length, num_heads):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads # e.g., 256 / 4 = 64

        # 1. The Giant Linear Layers (Combined for all heads)
        self.W_query = nn.Linear(d_in, d_out, bias=False)
        self.W_key   = nn.Linear(d_in, d_out, bias=False)
        self.W_value = nn.Linear(d_in, d_out, bias=False)

        # 2. Output projection (to mix the heads back together)
        self.out_proj = nn.Linear(d_out, d_out)

        # 3. Causal mask: mask[i, j] == 1 iff j <= i. Stored as a buffer so it
        #    is part of the module state but not a trainable parameter.
        self.register_buffer(
            'mask',
            torch.tril(torch.ones(context_length, context_length))
        )

    def forward(self, x):
        """Apply causal multi-head self-attention.

        Args:
            x: tensor of shape [batch, num_tokens, d_in] with
               num_tokens <= context_length.

        Returns:
            Tensor of shape [batch, num_tokens, d_out].
        """
        b, num_tokens, _ = x.shape

        # 1. Project to Q, K, V
        keys = self.W_key(x)      # Shape: [b, num_tokens, d_out]
        queries = self.W_query(x)
        values = self.W_value(x)

        # 2. Split the heads (Transform: [b, n, d_out] -> [b, n, num_heads, head_dim])
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)

        # Transpose: [b, num_tokens, num_heads, head_dim] -> [b, num_heads, num_tokens, head_dim]
        # This allows us to treat each head as a separate batch example
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # 3. Compute Attention Scores (Scaled Dot Product)
        # @ handles the broadcasting over the head dimension automatically
        attn_scores = queries @ keys.transpose(2, 3)

        # 4. Masking
        # Slice first, then convert: only the [num_tokens, num_tokens] window
        # is turned into a boolean tensor, not the whole context_length^2 buffer.
        mask_bool = self.mask[:num_tokens, :num_tokens].bool()
        attn_scores.masked_fill_(~mask_bool, -float('inf'))

        # 5. Softmax (scaled by sqrt of the per-head dimension)
        attn_weights = torch.softmax(attn_scores / self.head_dim**0.5, dim=-1)

        # 6. Weighted Sum
        context_vec = attn_weights @ values # [b, heads, tokens, head_dim]

        # 7. Concatenate Heads
        # Transpose back: [b, tokens, heads, head_dim]
        context_vec = context_vec.transpose(1, 2)

        # Flatten back: [b, tokens, d_out] (contiguous() is needed before view after a transpose)
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)

        # 8. Final Linear mixing
        return self.out_proj(context_vec)

In [14]:
torch.manual_seed(123)

# Configuration: a 4-dim output split across 2 heads -> 2 dims per head
d_in = 3
d_out = 4
num_heads = 2

mha = MultiHeadAttention(
    d_in=d_in,
    d_out=d_out,
    context_length=10,  # longest sequence the causal mask can cover
    num_heads=num_heads,
)

batch_input = inputs[None]  # add a batch dim: [3, 3] -> [1, 3, 3]
out = mha(batch_input)

print("Input shape: ", batch_input.shape)
print("Output shape:", out.shape)
print(out)

Input shape:  torch.Size([1, 3, 3])
Output shape: torch.Size([1, 3, 4])
tensor([[[ 0.1184,  0.3120, -0.0847, -0.5774],
         [ 0.0178,  0.3221, -0.0763, -0.4225],
         [-0.0147,  0.3259, -0.0734, -0.3721]]], grad_fn=<ViewBackward0>)
