In [21]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Fix the RNG seed so randomly initialised weights are reproducible
torch.manual_seed(123)

# SIMULATED INPUT
# A toy "sentence" made of 3 tokens, one row per token.
# Each token is embedded as a 3-dimensional vector (kept tiny for readability).
# Shape: [Sequence_Length, Embedding_Dim] = [3, 3]
inputs = torch.tensor([
    [0.43, 0.15, 0.89],  # Token 1 (e.g., "The")
    [0.55, 0.87, 0.66],  # Token 2 (e.g., "cat")
    [0.57, 0.85, 0.64],  # Token 3 (e.g., "sat")
])

# Dimensions
d_in = inputs.size(-1)  # embedding width of each token -> 3
d_out = 2               # width we project the tokens down to

In [22]:
# 1. Use the 2nd token (index 1) as the "query"
query = inputs[1]

# 2. Score the query against every token (itself included) with a dot
#    product -- a larger value means the two vectors are more similar.
#    The per-token scalars are stacked into a 1-D tensor of length Seq_Len.
attn_scores_2 = torch.stack([torch.dot(query, token_vec) for token_vec in inputs])

print("Raw Similarity Scores:", attn_scores_2)

Raw Similarity Scores: tensor([0.9544, 1.4950, 1.4754])


In [23]:
# Softmax maps the raw scores to positive weights that sum to 1,
# so they can be read as a probability distribution over the tokens
attn_weights_2 = attn_scores_2.softmax(dim=0)

print("Attention Weights:", attn_weights_2)
print("Sum check:", attn_weights_2.sum())

Attention Weights: tensor([0.2272, 0.3902, 0.3826])
Sum check: tensor(1.)


In [24]:
# Accumulator for the weighted sum; same shape as a single token vector
context_vec_2 = torch.zeros(query.size())

# Context vector = sum over tokens of (attention weight * token vector)
for weight, token_vec in zip(attn_weights_2, inputs):
    context_vec_2 += weight * token_vec

print("Original 'cat' vector:", query)
print("New 'Context' vector: ", context_vec_2)

Original 'cat' vector: tensor([0.5500, 0.8700, 0.6600])
New 'Context' vector:  tensor([0.5304, 0.6987, 0.7046])


In [25]:
# Pairwise dot products of all tokens: entry [i, j] scores token i against token j
attn_scores = torch.matmul(inputs, inputs.T)

In [26]:
# Row-wise softmax: each row (one query token) becomes weights that sum to 1
attn_weights = attn_scores.softmax(dim=-1)

print("Full Attention Weights:\n", attn_weights)
print("Row 1 (Cat) sums to:", attn_weights[1].sum())

Full Attention Weights:
 tensor([[0.3448, 0.3296, 0.3256],
        [0.2272, 0.3902, 0.3826],
        [0.2284, 0.3893, 0.3822]])
Row 1 (Cat) sums to: tensor(1.)


In [27]:
# A single matrix product yields every token's context vector:
# row i is the attention-weighted sum of all input rows for token i
all_context_vecs = torch.matmul(attn_weights, inputs)

print("All Context Vectors shape:", all_context_vecs.shape)
print("Context Vector for 'cat':\n", all_context_vecs[1])

All Context Vectors shape: torch.Size([3, 3])
Context Vector for 'cat':
 tensor([0.5304, 0.6987, 0.7046])


In [None]:
class SelfAttention_v1(nn.Module):
    """Single-head scaled dot-product self-attention with trainable projections.

    The input is projected into queries, keys and values by three bias-free
    linear layers. Each output row is a softmax-weighted sum of the value
    vectors, where the weights come from query/key dot products scaled by
    ``sqrt(d_out)``.

    Args:
        d_in: Size of the last dimension of the input (embedding width).
        d_out: Size of the last dimension of the queries, keys, values and
            therefore of the returned context vectors.
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        self.d_out = d_out

        # Define the 3 projection matrices (W_q, W_k, W_v)
        # We use nn.Linear because it handles the initialization for us
        self.W_query = nn.Linear(d_in, d_out, bias=False)
        self.W_key   = nn.Linear(d_in, d_out, bias=False)
        self.W_value = nn.Linear(d_in, d_out, bias=False)

    def forward(self, x):
        """Compute one context vector per input token.

        Args:
            x: Tensor of shape ``[..., Seq_Len, d_in]``. Both a single
                unbatched sequence ``[Seq_Len, d_in]`` and a batch
                ``[Batch, Seq_Len, d_in]`` are accepted.

        Returns:
            Tensor of shape ``[..., Seq_Len, d_out]`` with the same leading
            dimensions as ``x``.
        """
        # 1. Calculate Q, K, V -- each has shape [..., Seq_Len, d_out]
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        # 2. Calculate Attention Scores (Queries @ Keys^T)
        # Swap the LAST two dimensions (Seq_Len and d_out) using negative
        # indices so this works with or without leading batch dimensions.
        # Result shape: [..., Seq_Len, Seq_Len]
        attn_scores = queries @ keys.transpose(-2, -1)

        # 3. Compute Attention Weights
        # Scale by sqrt(d_out) before the softmax, then normalize each row
        # (one row per query token) so its weights sum to 1.
        attn_weights = torch.softmax(attn_scores / self.d_out**0.5, dim=-1)

        # 4. Compute Context Vectors: weighted sum of the value vectors
        context_vec = attn_weights @ values

        return context_vec