# Attention Mechanism Lab with Multihead Self-Attention in PyTorch

### What Is Attention?

Before diving into code, let’s understand why we need attention.

In NLP (Natural Language Processing), models must decide which parts of a sentence are important to look at. Attention lets each token focus more on relevant tokens in the sequence.

For example:

* In the sentence “The food was hot because it was spicy”, the word “it” should attend to “food”.

### Understanding Input Dimensions

We’re working with 3D tensors in this lab:

* **B**: Batch size (how many sentences we process at once)
* **T**: Number of tokens in each sentence
* **C**: Number of features (embeddings) for each token
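As a quick illustration of this layout, the sketch below builds a random tensor with the same `B`, `T`, `C` values used later in this lab and indexes into it. The values are random placeholders, not real embeddings.

```python
import torch

B, T, C = 2, 4, 8           # batch size, tokens per sentence, features per token
x = torch.randn(B, T, C)    # random stand-in for embedded tokens

print(x.shape)        # shape of the whole batch: [B, T, C]
print(x[0].shape)     # one sentence: T tokens, each with C features
print(x[0, 1].shape)  # one token: a vector of C features
```

Indexing from the left peels off one dimension at a time: first the sentence, then the token within it.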

### Part 1: Implementing a Basic Self-Attention Head

### What’s a Self-Attention Head?

A single attention head is a mechanism that allows a token to:

* **Query (Q)**: Ask a question (e.g., what am I looking for?)
* **Key (K)**: Answer if it has what the query wants
* **Value (V)**: Send information if the key matches

Tokens compare queries and keys using dot products and weigh each token’s value based on these scores.
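In equation form, this comparison is the standard scaled dot-product attention. The symbol $d_k$ (the size of each key vector) is notation introduced here for illustration; in this lab's code it corresponds to `C`.

$$
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V
$$

Dividing by $\sqrt{d_k}$ keeps the scores from growing with the vector size, which would otherwise push the softmax toward a near one-hot distribution.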

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(1337)

# B = batch, T = token length, C = channel/feature size
B, T, C = 2, 4, 8
x = torch.randn(B, T, C)  # Random input simulating embedded tokens

# Learnable linear layers for projecting to Q, K, V
key = nn.Linear(C, C, bias=False)
query = nn.Linear(C, C, bias=False)
value = nn.Linear(C, C, bias=False)

# Generate Q, K, V from input
k = key(x)
q = query(x)
v = value(x)

# Compute raw attention scores by dot product of Q and K
scores = q @ k.transpose(-2, -1) / C**0.5

# Mask to prevent attending to future tokens (for decoder-style attention)
mask = torch.tril(torch.ones(T, T))
scores = scores.masked_fill(mask == 0, float('-inf'))

# Normalize scores into probabilities
weights = F.softmax(scores, dim=-1)

# Weighted sum of values based on attention weights
out = weights @ v

# Print intermediate results (only for the first sample for brevity)
print("Input x (batch 0):\n", x[0])
print("\nQuery q (batch 0):\n", q[0])
print("\nKey k (batch 0):\n", k[0])
print("\nValue v (batch 0):\n", v[0])
print("\nAttention Scores (batch 0):\n", scores[0])
print("\nAttention Weights (batch 0):\n", weights[0])
print("\nFinal Output (batch 0):\n", out[0])

Input x (batch 0):
 tensor([[ 0.1808, -0.0700, -0.3596, -0.9152,  0.6258,  0.0255,  0.9545,  0.0643],
        [ 0.3612,  1.1679, -1.3499, -0.5102,  0.2360, -0.2398, -0.9211,  1.5433],
        [ 1.3488, -0.1396,  0.2858,  0.9651, -2.0371,  0.4931,  1.4870,  0.5910],
        [ 0.1260, -1.5627, -1.1601, -0.3348,  0.4478, -0.8016,  1.5236,  2.5086]])

Query q (batch 0):
 tensor([[ 0.0110, -0.1263,  0.4046,  0.0332,  0.2670, -0.1756,  0.1955, -0.0063],
        [ 0.6874, -0.0661, -0.2071,  0.3671,  0.6184, -0.5196,  0.4594, -0.2400],
        [ 0.0537,  0.4173,  0.5473,  0.2127, -0.3923,  0.7439,  0.6044, -0.9145],
        [-0.2882,  0.1333,  0.5673,  1.1221,  0.8594,  0.3488,  0.5764, -0.1420]],
       grad_fn=<SelectBackward0>)

Key k (batch 0):
 tensor([[ 0.1337,  0.1884,  0.2608,  0.0804, -0.4088,  0.2697, -0.0554,  0.2315],
        [ 0.7615,  0.6625, -0.1835,  0.1656,  0.8969,  0.4457, -0.2870,  0.5668],
        [-0.6749, -0.9168, -0.1317,  0.0897,  1.0475,  0.3863,  1.2753, -0.6169],
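Rather than checking the printed tensors by eye, a few assertions-style checks can confirm that the masked softmax behaves as intended. The sketch below is a minimal stand-alone version that uses random `q`, `k`, `v` as stand-ins for the projected tensors (an assumption made to keep it self-contained).

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
T, C = 4, 8
# Random stand-ins for the projected queries, keys, and values of one sentence
q, k, v = torch.randn(T, C), torch.randn(T, C), torch.randn(T, C)

scores = q @ k.T / C**0.5
mask = torch.tril(torch.ones(T, T))
scores = scores.masked_fill(mask == 0, float('-inf'))
weights = F.softmax(scores, dim=-1)
out = weights @ v

row_sums = weights.sum(dim=-1)                                # each row should sum to 1
future_weight = torch.triu(weights, diagonal=1).abs().sum()   # weight placed on future tokens
print(row_sums)
print(future_weight)

# The first token can only attend to itself, so its output is its own value vector
print(torch.allclose(out[0], v[0]))
```

The `-inf` entries become exactly zero after the softmax, which is why no weight leaks to future positions.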
  

### Part 2: Making It a Reusable Class

Now let’s put the above logic into a reusable PyTorch module called `Head`. This will represent one attention head.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Head(nn.Module):
    """ One head of self-attention """

    def __init__(self, C_head):
        super().__init__()
        # Define linear projections for keys, queries, and values
        self.key = nn.Linear(C_head, C_head, bias=False)     # Projects input to 'key' vector
        self.query = nn.Linear(C_head, C_head, bias=False)   # Projects input to 'query' vector
        self.value = nn.Linear(C_head, C_head, bias=False)   # Projects input to 'value' vector

    def forward(self, x):
        # Input x: shape [B, T, C]
        B, T, C = x.shape  # B: batch size, T: sequence length, C: embedding dimension

        # Linear projections
        k = self.key(x)    # [B, T, C] → key vectors
        q = self.query(x)  # [B, T, C] → query vectors
        v = self.value(x)  # [B, T, C] → value vectors

        # Compute raw attention scores via scaled dot product
        # scores[i][j] = similarity between query i and key j
        # Shape: [B, T, T] where each score[i][j] is q[i] ⋅ k[j]
        scores = q @ k.transpose(-2, -1) / C**0.5  # Scale by sqrt(C) to stabilize gradients

        # Apply causal (lower triangular) mask to prevent attending to future tokens
        # Ensures token i only attends to token j ≤ i (causal self-attention)
        mask = torch.tril(torch.ones(T, T, device=x.device))  # Shape: [T, T], created on the same device as x
        scores = scores.masked_fill(mask == 0, float('-inf'))  # Set future token scores to -inf

        # Softmax to normalize scores into attention weights
        # Each row of 'weights' sums to 1: it tells how much attention token i pays to j
        weights = F.softmax(scores, dim=-1)  # Shape: [B, T, T]

        # Final output: each token is a weighted sum of the value vectors
        # Attention-weighted sum across time dimension (T)
        out = weights @ v  # Shape: [B, T, C]

        return out
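One alternative to building the mask inside `forward` on every call is to create it once and register it as a buffer, so it moves with the module when `.to(device)` is called. The sketch below shows that idea; `BufferedHead` and its `block_size` argument are hypothetical names introduced here, not part of the lab's `Head`. It also checks causality by changing the last token and confirming that earlier outputs stay the same.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BufferedHead(nn.Module):
    """Hypothetical variant of Head that stores the causal mask as a buffer."""

    def __init__(self, C_head, block_size):
        super().__init__()
        self.key = nn.Linear(C_head, C_head, bias=False)
        self.query = nn.Linear(C_head, C_head, bias=False)
        self.value = nn.Linear(C_head, C_head, bias=False)
        # Buffers are not trained, but they follow the module across .to(device)
        self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        B, T, C = x.shape
        k, q, v = self.key(x), self.query(x), self.value(x)
        scores = q @ k.transpose(-2, -1) / C**0.5
        # Slice the stored mask in case the sequence is shorter than block_size
        scores = scores.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        return F.softmax(scores, dim=-1) @ v

torch.manual_seed(0)
head = BufferedHead(C_head=8, block_size=4)
x = torch.randn(1, 4, 8)
x_changed = x.clone()
x_changed[0, 3] = torch.randn(8)   # alter only the last token

out_a, out_b = head(x), head(x_changed)
# Causality check: outputs for tokens 0..2 should not depend on token 3
print(torch.allclose(out_a[0, :3], out_b[0, :3]))
```

The trade-off is that the maximum sequence length has to be known when the module is constructed.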


### Part 3: Multihead Attention

### Why Multiple Heads?

Each head has its own weights, so it can learn to attend to a different kind of relationship. For example:

* One may focus on grammar
* Another on meaning
* Another on position

Multihead attention allows parallel learning from multiple viewpoints.


In [None]:
class MultiHeadAttention(nn.Module):
    """ Multiple heads of self-attention in parallel """

    def __init__(self, C, num_heads):
        super().__init__()
        assert C % num_heads == 0, "Embedding dimension must be divisible by num_heads"

        # Create a list of independent attention heads
        # Each head processes a portion of the input features
        self.C = C
        self.num_heads = num_heads
        self.C_head = C // num_heads

        self.heads = nn.ModuleList([
            Head(self.C_head) for _ in range(num_heads)
        ])
        self.proj = nn.Linear(C, C)

    def forward(self, x):
        # x: [B, T, C]
        head_outputs = []
        for i, head in enumerate(self.heads):
            # Slice feature dimension for each head: [B, T, C_head]
            x_split = x[..., i * self.C_head : (i + 1) * self.C_head]
            head_outputs.append(head(x_split))

        # Concatenate along feature dimension: [B, T, C]
        out = torch.cat(head_outputs, dim=-1)
        return self.proj(out)
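To make the feature split concrete, the sketch below uses illustrative sizes (`C = 16`, `num_heads = 2`) and shows that slicing the last dimension per head gives the same chunks as viewing the tensor as `[B, T, num_heads, C_head]`, and that concatenating the chunks restores the original layout.

```python
import torch

B, T, C, num_heads = 2, 4, 16, 2
C_head = C // num_heads
x = torch.randn(B, T, C)

# Slicing, as in the loop over heads: one [B, T, C_head] chunk per head
slices = [x[..., i * C_head:(i + 1) * C_head] for i in range(num_heads)]

# Equivalent view: split the last dimension into (num_heads, C_head)
x_heads = x.view(B, T, num_heads, C_head)

print([s.shape for s in slices])
print(all(torch.equal(slices[i], x_heads[:, :, i]) for i in range(num_heads)))

# Concatenating the chunks restores the original feature layout
print(torch.equal(torch.cat(slices, dim=-1), x))
```

Note that in this lab's design each head sees only its own slice of the input features. Many implementations instead project the full input first and split the projected result across heads.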



### Part 4: Simple Language Model Using Attention

We’ll now build a basic language model that learns to predict the next character given the previous ones.

### Key Components:

* **Token Embedding**: Turns characters into vectors
* **Positional Embedding**: A learned vector per position that tells the model where each token is
* **Multihead Attention**: Helps tokens attend to previous ones
* **Output Layer**: Predicts the next token

In [None]:
class SimpleLanguageModel(nn.Module):
    def __init__(self, vocab_size, block_size, n_embd, n_head):
        super().__init__()

        # Embedding for tokens (individual characters in this lab)
        # Maps each token index to a learnable n_embd-dimensional vector
        self.token_emb = nn.Embedding(vocab_size, n_embd)

        # Embedding for positions (e.g., positions 0, 1, ..., block_size-1)
        # Gives the model a sense of order in the sequence
        self.pos_emb = nn.Embedding(block_size, n_embd)

        # Multi-head self-attention block (causal)
        # Learns dependencies between tokens at different positions
        self.attn = MultiHeadAttention(n_embd, n_head)

        # Final projection layer to output logits over the vocabulary
        self.output = nn.Linear(n_embd, vocab_size)

    def forward(self, idx):
        # idx: [B, T] — input token indices (B = batch size, T = sequence length)

        B, T = idx.shape  # Extract dimensions

        # Token embeddings: shape [B, T, n_embd]
        tok = self.token_emb(idx)

        # Positional embeddings: shape [T, n_embd], then broadcast to [B, T, n_embd]
        pos = self.pos_emb(torch.arange(T, device=idx.device))

        # Combine token and position embeddings
        x = tok + pos  # Shape: [B, T, n_embd]

        # Pass through multi-head self-attention block
        x = self.attn(x)  # Shape: [B, T, n_embd]

        # Final linear layer to map embeddings to vocabulary logits
        logits = self.output(x)  # Shape: [B, T, vocab_size]

        return logits  # These logits are used to compute softmax and loss externally
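The step that adds token and positional embeddings relies on broadcasting: the positional tensor has no batch dimension, so the same position vectors are added to every sequence in the batch. A minimal sketch, with sizes chosen only for illustration:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
vocab_size, block_size, n_embd = 8, 8, 16    # illustrative sizes
token_emb = nn.Embedding(vocab_size, n_embd)
pos_emb = nn.Embedding(block_size, n_embd)

idx = torch.randint(0, vocab_size, (2, 5))   # [B=2, T=5] token indices
tok = token_emb(idx)                         # [2, 5, 16]
pos = pos_emb(torch.arange(idx.shape[1]))    # [5, 16]
x = tok + pos                                # pos is broadcast over the batch

print(tok.shape, pos.shape, x.shape)
# Every sequence in the batch receives the same positional vectors
print(torch.allclose(x[0] - tok[0], x[1] - tok[1], atol=1e-5))
```

Because `pos_emb` only has `block_size` rows, `T` must not exceed `block_size`, otherwise the lookup raises an index error.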

### Part 5: Training on Simple Text

Let’s train this model on a very small example: "hello world". First we turn the text into token indices and write a helper that samples training batches.

In [None]:
text = "hello world"  # The input text corpus for training a toy language model

# Extract all unique characters, sort them (just for consistency)
chars = sorted(set(text))  # [' ', 'd', 'e', 'h', 'l', 'o', 'r', 'w']

# Create a mapping from character to integer (index)
stoi = {ch: i for i, ch in enumerate(chars)}  # "string to index"
itos = {i: ch for ch, i in stoi.items()}       # "index to string"

# Define encoding and decoding functions
encode = lambda s: [stoi[c] for c in s]        # Converts string → list of integers
decode = lambda l: ''.join([itos[i] for i in l])  # Converts list of integers → string

# Convert the entire string into a tensor of token indices
data = torch.tensor(encode(text), dtype=torch.long)
# For "hello world", this will look like: tensor([3, 2, 4, 4, 5, 0, 7, 5, 6, 4, 1])

def get_batch(batch_size=4, block_size=8):
    # Randomly pick `batch_size` starting indices for sequences of `block_size` tokens
    ix = torch.randint(0, len(data) - block_size, (batch_size,))
    # For each index `i`, take a chunk of `block_size` tokens as input
    x = torch.stack([data[i:i+block_size] for i in ix])  # Shape: [batch_size, block_size]

    # The target for each chunk is the same window shifted one token to the right
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])  # Shape: [batch_size, block_size]

    return x, y  # x is input, y is expected output (next character)
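To see how inputs and targets line up, the sketch below works on the raw characters instead of token indices. It takes one window starting at position 0 (an arbitrary choice for illustration) and prints the context available at each position alongside the character to predict.

```python
text = "hello world"
block_size = 8   # same window length as get_batch

start = 0
x_chars = text[start:start + block_size]          # model input
y_chars = text[start + 1:start + block_size + 1]  # targets, shifted by one

for t in range(block_size):
    # With causal attention, position t may use x_chars[:t + 1] to predict y_chars[t]
    print(repr(x_chars[:t + 1]), "->", repr(y_chars[t]))
```

A single window of 8 tokens therefore provides 8 training examples, one per position.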


### Part 6: Training Loop

Train the model on small sequences of characters:

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SimpleLanguageModel(len(chars), block_size=8, n_embd=16, n_head=2).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)

for step in range(200):
    x, y = get_batch(batch_size=64)
    x, y = x.to(device), y.to(device)
    logits = model(x)
    loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))
    opt.zero_grad()
    loss.backward()
    opt.step()
    if step % 50 == 0:
        print(f"Step {step}, Loss: {loss.item():.4f}")


Step 0, Loss: 2.1783
Step 50, Loss: 1.5632
Step 100, Loss: 0.8687
Step 150, Loss: 0.3847
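A useful reference point for these numbers is the loss of a model with no preference among the 8 characters, which is ln(8) ≈ 2.079. The step-0 loss above is close to that value. The sketch below reproduces this baseline with all-zero logits and also shows the reshape that `F.cross_entropy` needs; the tensor sizes are chosen for illustration.

```python
import math
import torch
import torch.nn.functional as F

B, T, vocab_size = 4, 8, 8
logits = torch.zeros(B, T, vocab_size)            # a model with no preference
targets = torch.randint(0, vocab_size, (B, T))

# cross_entropy expects [N, num_classes] logits and [N] targets,
# so the batch and time dimensions are merged first
loss = F.cross_entropy(logits.view(-1, vocab_size), targets.view(-1))

print(f"{loss.item():.4f}")
print(f"{math.log(vocab_size):.4f}")
```

A loss well below this baseline indicates the model has learned something about which character tends to follow which.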


### Part 7: Text Generation

Generate new text starting from a seed character.

Try changing the seed character to understand the text generation mechanism.

In [None]:
@torch.no_grad()
def generate(model, start_str, max_new_tokens=100):
    model.eval()

    # Encode the start string into tensor of token indices
    idx = torch.tensor([encode(start_str)], dtype=torch.long).to(device)  # Shape: [1, T]

    for _ in range(max_new_tokens):
        # Keep only the last 8 tokens (the block_size), since the model has positions 0..7 only
        idx_cond = idx[:, -8:]

        # Get model predictions
        logits = model(idx_cond)  # [1, T, vocab_size]
        last_logits = logits[:, -1, :]  # Take the last time step's logits: [1, vocab_size]

        # Convert logits to probabilities
        probs = F.softmax(last_logits, dim=-1)  # [1, vocab_size]

        # Sample the next token at random from the predicted distribution (not greedy)
        next_token = torch.multinomial(probs, num_samples=1)  # [1, 1]

        # Append the next token to the sequence
        idx = torch.cat([idx, next_token], dim=1)  # [1, T+1]

    # Decode the entire sequence to string
    return decode(idx[0].tolist())

print(generate(model, start_str="w", max_new_tokens=100))


wrllo worldorlhld llwllwllorlhrllllldorldorlhld ldorlorl llllllllo lo worldlllrlhlworldorlrllldorlllo
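The generated text varies because the next token is drawn at random with `torch.multinomial` instead of always taking the most likely one. The sketch below contrasts the two choices on a made-up distribution over three tokens; the probabilities are invented for illustration.

```python
import torch

torch.manual_seed(0)
probs = torch.tensor([[0.1, 0.6, 0.3]])   # made-up next-token distribution

# Greedy decoding: always pick the most probable token
greedy = torch.argmax(probs, dim=-1, keepdim=True)

# Random sampling: draw tokens in proportion to their probabilities
samples = torch.multinomial(probs[0], num_samples=1000, replacement=True)
counts = torch.bincount(samples, minlength=3)

print(greedy)   # tensor([[1]])
print(counts)   # counts roughly proportional to 0.1, 0.6, 0.3
```

Greedy decoding is deterministic and tends to repeat itself, while sampling trades some accuracy for variety.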
