In [1]:
import torch
import torch.nn.functional as F
import math

# For better display in notebooks
from IPython.display import display, Markdown

In [3]:
# Embedding dimension of every token vector (d_k in the attention formula;
# its square root is the scaling factor applied to the raw scores later on).
EMBEDDING_DIM = 4

# Sequence length (number of tokens/words in the toy sentence).
SEQUENCE_LENGTH = 3

# --- Simulate Input Embeddings (our "Value" vectors initially) ---
# In a real model, these would come from an embedding layer
# or a previous transformer block.
# Here we represent 3 tokens, each with an embedding of 4 dimensions.
# Shape: (SEQUENCE_LENGTH, EMBEDDING_DIM)
embedded_tokens = torch.tensor([
    [0.1, 0.2, 0.3, 0.4], # Token 0: "He"
    [0.5, 0.6, 0.7, 0.8], # Token 1: "is"
    [0.9, 0.0, 0.1, 0.2], # Token 2: "awesome"
], dtype=torch.float32)

# Fail fast if the hand-written tensor and the two constants above ever drift
# apart; later cells rely on EMBEDDING_DIM matching the last tensor dimension.
assert embedded_tokens.shape == (SEQUENCE_LENGTH, EMBEDDING_DIM), (
    f"embedded_tokens has shape {tuple(embedded_tokens.shape)}, "
    f"expected ({SEQUENCE_LENGTH}, {EMBEDDING_DIM})"
)

display(Markdown(f"**`embedded_tokens` (simulated input/value vectors) Shape:** `{embedded_tokens.shape}`"))
display(embedded_tokens)
print("\n---\n")

# --- Simulate Query and Key Matrices ---
# In self-attention, Q, K and V typically come from linear transformations
# of the same input embeddings. For simplicity, this notebook uses the
# embeddings themselves for all three.
# In a real scenario, you'd have:
# Q = embedded_tokens @ W_q
# K = embedded_tokens @ W_k
# V = embedded_tokens @ W_v
# where W_q, W_k, W_v are weight matrices.

**`embedded_tokens` (simulated input/value vectors) Shape:** `torch.Size([3, 4])`

tensor([[0.1000, 0.2000, 0.3000, 0.4000],
        [0.5000, 0.6000, 0.7000, 0.8000],
        [0.9000, 0.0000, 0.1000, 0.2000]])


---



In [5]:
# Self-attention demo without learned projections: Q, K and V are all bound to
# the very same tensor object, so every token attends to every token
# (including itself) using its raw embedding.
queries = keys = values = embedded_tokens

# Show each matrix right after its shape heading, in Q, K, V order.
display(
    Markdown(f"**`queries` Shape:** `{queries.shape}`"),
    queries,
    Markdown(f"**`keys` Shape:** `{keys.shape}`"),
    keys,
    Markdown(f"**`values` (same as embedded_tokens for this example) Shape:** `{values.shape}`"),
    values,
)

**`queries` Shape:** `torch.Size([3, 4])`

tensor([[0.1000, 0.2000, 0.3000, 0.4000],
        [0.5000, 0.6000, 0.7000, 0.8000],
        [0.9000, 0.0000, 0.1000, 0.2000]])

**`keys` Shape:** `torch.Size([3, 4])`

tensor([[0.1000, 0.2000, 0.3000, 0.4000],
        [0.5000, 0.6000, 0.7000, 0.8000],
        [0.9000, 0.0000, 0.1000, 0.2000]])

**`values` (same as embedded_tokens for this example) Shape:** `torch.Size([3, 4])`

tensor([[0.1000, 0.2000, 0.3000, 0.4000],
        [0.5000, 0.6000, 0.7000, 0.8000],
        [0.9000, 0.0000, 0.1000, 0.2000]])

In [8]:
# Raw attention scores: Q K^T
#   queries                 -> (seq_len, embedding_dim)
#   keys.transpose(-2, -1)  -> (embedding_dim, seq_len)
#   attn_weights            -> (seq_len, seq_len)
# Entry (i, j) is the dot product of query_i with key_j.
attn_weights = queries @ keys.transpose(-2, -1)
attn_weights

tensor([[0.3000, 0.7000, 0.2000],
        [0.7000, 1.7400, 0.6800],
        [0.2000, 0.6800, 0.8600]])

In [9]:
# Present the raw (unscaled, unnormalized) score matrix under a shape heading.
display(
    Markdown(f"**Raw Attention Weights ($QK^T$) Shape:** `{attn_weights.shape}`"),
    attn_weights,
)
print("\n---")
# Reading guide for the matrix: rows index queries, columns index keys.
display(
    Markdown("Each row corresponds to a Query, and each column to a Key."),
    Markdown("For example, `attn_weights[0, 1]` is the similarity between Query 0 and Key 1."),
)

**Raw Attention Weights ($QK^T$) Shape:** `torch.Size([3, 3])`

tensor([[0.3000, 0.7000, 0.2000],
        [0.7000, 1.7400, 0.6800],
        [0.2000, 0.6800, 0.8600]])


---


Each row corresponds to a Query, and each column to a Key.

For example, `attn_weights[0, 1]` is the similarity between Query 0 and Key 1.

In [10]:
# Scaling: divide the raw scores by sqrt(d_k), where d_k is the key dimension
# (EMBEDDING_DIM here, because the keys are the embeddings themselves).
scaling_factor = math.sqrt(EMBEDDING_DIM)
attn_weights_scaled = attn_weights / scaling_factor

# Raw f-string: "\s" (from "\sqrt") is not a valid Python escape sequence, so a
# non-raw literal triggers an invalid-escape warning when the cell is compiled.
# Doubled braces still render as the literal "{d_k}" in a raw f-string.
display(Markdown(rf"**Scaling Factor ($\sqrt{{d_k}}$):** `{scaling_factor:.4f}`"))
display(Markdown(f"**Scaled Attention Weights Shape:** `{attn_weights_scaled.shape}`"))
display(attn_weights_scaled)
print("\n---")
display(Markdown("Notice how the values are now smaller, which helps prevent vanishing gradients after softmax."))

  display(Markdown(f"**Scaling Factor ($\sqrt{{d_k}}$):** `{scaling_factor:.4f}`"))


**Scaling Factor ($\sqrt{d_k}$):** `2.0000`

**Scaled Attention Weights Shape:** `torch.Size([3, 3])`

tensor([[0.1500, 0.3500, 0.1000],
        [0.3500, 0.8700, 0.3400],
        [0.1000, 0.3400, 0.4300]])


---


Notice how the values are now smaller, which helps prevent vanishing gradients after softmax.

In [12]:
# Apply softmax along the last dimension, i.e. across the keys (columns).
# dim=-1 always selects the key axis, matching the transpose(-2, -1) used to
# build the scores; a hard-coded dim=1 is only correct for a 2-D tensor.
# For each query (row), the attention weights over all keys then sum to 1.
attn_weights_norm = F.softmax(attn_weights_scaled, dim=-1)

# Inline sanity check: every row must be a probability distribution.
row_sums = attn_weights_norm.sum(dim=-1)
assert torch.allclose(row_sums, torch.ones_like(row_sums)), "softmax rows must sum to 1"

display(Markdown(f"**Normalized Attention Weights (Softmax Output) Shape:** `{attn_weights_norm.shape}`"))
display(attn_weights_norm)
print("\n---")
display(Markdown("Each row now represents a probability distribution. Let's check a row's sum:"))
# One line per query row of this 3-token example (zip stops at the shorter input).
for ordinal, row_sum in zip(("first", "second", "third"), row_sums):
    print(f"Sum of {ordinal} row: {row_sum.item():.4f}")

**Normalized Attention Weights (Softmax Output) Shape:** `torch.Size([3, 3])`

tensor([[0.3152, 0.3850, 0.2998],
        [0.2723, 0.4581, 0.2696],
        [0.2731, 0.3471, 0.3798]])


---


Each row now represents a probability distribution. Let's check a row's sum:

Sum of first row: 1.0000
Sum of second row: 1.0000
Sum of third row: 1.0000


In [13]:
# Re-display the softmax-normalized attention weights as this cell's output.
attn_weights_norm

tensor([[0.3152, 0.3850, 0.2998],
        [0.2723, 0.4581, 0.2696],
        [0.2731, 0.3471, 0.3798]])

In [14]:
# Weighted sum of the value vectors, one output row per query:
#   attn_weights_norm -> (num_queries, num_keys)    = (SEQUENCE_LENGTH, SEQUENCE_LENGTH)
#   values            -> (num_keys, embedding_dim)  = (SEQUENCE_LENGTH, EMBEDDING_DIM)
#   result            -> (num_queries, embedding_dim)
context_weighted_embeddings = attn_weights_norm @ values

display(
    Markdown(f"**Context-Weighted Embeddings Shape:** `{context_weighted_embeddings.shape}`"),
    context_weighted_embeddings,
)
print("\n---")
display(Markdown("Each row is a new, context-aware representation for the corresponding input token."))

**Context-Weighted Embeddings Shape:** `torch.Size([3, 4])`

tensor([[0.4939, 0.2940, 0.3940, 0.4940],
        [0.4989, 0.3293, 0.4293, 0.5293],
        [0.5427, 0.2629, 0.3629, 0.4629]])


---


Each row is a new, context-aware representation for the corresponding input token.

In [16]:
# Introduce the hand calculation for the last token. Index 2 is the 3rd query
# (the sequence only has three tokens), so the headings say "3rd", in line
# with the later comparison cells.
display(Markdown("**Manual Calculation for 3rd Query (Index 2):**"))
display(Markdown(f"The 3rd query's attention weights are: `{attn_weights_norm[2]}`"))
display(Markdown("The original value vectors (`embedded_tokens`) are: "))
display(values)

**Manual Calculation for 3rd Query (Index 2):**

The 3rd query's attention weights are: `tensor([0.2731, 0.3471, 0.3798])`

The original value vectors (`embedded_tokens`) are: 

tensor([[0.1000, 0.2000, 0.3000, 0.4000],
        [0.5000, 0.6000, 0.7000, 0.8000],
        [0.9000, 0.0000, 0.1000, 0.2000]])

In [21]:
# Rebuild the context vector for query index 2 by hand: scale each value
# vector by that query's attention weight for the matching key, then add the
# three scaled vectors in key order.
context_weighted_embeddings_2_manual = (
    values[0] * attn_weights_norm[2][0]
    + values[1] * attn_weights_norm[2][1]
    + values[2] * attn_weights_norm[2][2]
)

# Show the hand-computed row next to the corresponding row of the matmul result.
display(
    Markdown(f"**Manually Calculated Context for 3rd Query:** `{context_weighted_embeddings_2_manual}`"),
    Markdown(f"**PyTorch Calculated Context for 3rd Query:** `{context_weighted_embeddings[2]}`"),
)

**Manually Calculated Context for 3rd Query:** `tensor([0.5427, 0.2629, 0.3629, 0.4629])`

**PyTorch Calculated Context for 3rd Query:** `tensor([0.5427, 0.2629, 0.3629, 0.4629])`

In [23]:
# Check if they are approximately equal
if torch.allclose(context_weighted_embeddings_2_manual, context_weighted_embeddings[2]):
    display(Markdown("The manual calculation matches the `torch.matmul` result!"))
else:
    display(Markdown("Mismatch detected (shouldn't happen if numbers are floats, might be tiny precision diff)."))

The manual calculation matches the `torch.matmul` result!