# Multi-head Attention Plus Data Loading

The complete chapter code is located in [ch03.ipynb](./ch03.ipynb).

This notebook contains the main takeaway of the chapter, the multi-head attention implementation (plus the data loading pipeline from chapter 2).

## Data Loader from Chapter 2

In [1]:
import tiktoken
import torch
from torch.utils.data import Dataset, DataLoader


class GPTDatasetV1(Dataset):
    """Sliding-window dataset of (input, target) token-id tensors.

    The text is tokenized once with ``tokenizer.encode``. Every window of
    ``max_length`` token ids becomes an input, and the same window shifted
    one position to the right becomes its target. Consecutive windows start
    ``stride`` tokens apart.
    """

    def __init__(self, txt, tokenizer, max_length, stride):
        self.tokenizer = tokenizer

        # Encode the whole text up front
        ids = tokenizer.encode(txt)

        # A window needs max_length + 1 ids (the target is the input shifted
        # by one), so valid start offsets stop before len(ids) - max_length
        starts = range(0, len(ids) - max_length, stride)

        self.input_ids = [
            torch.tensor(ids[start:start + max_length]) for start in starts
        ]
        self.target_ids = [
            torch.tensor(ids[start + 1:start + max_length + 1]) for start in starts
        ]

    def __len__(self):
        # One sample per window
        return len(self.input_ids)

    def __getitem__(self, idx):
        inputs = self.input_ids[idx]
        targets = self.target_ids[idx]
        return inputs, targets


def create_dataloader(txt, batch_size=4, max_length=256, stride=128):
    """Build a DataLoader of sliding-window (input, target) batches from text.

    The text is tokenized with tiktoken's "gpt2" encoding and wrapped in a
    GPTDatasetV1 using the given ``max_length`` and ``stride``. Only
    ``batch_size`` is passed to the DataLoader; no other options are set.
    """
    gpt2_tokenizer = tiktoken.get_encoding("gpt2")
    windowed_dataset = GPTDatasetV1(txt, gpt2_tokenizer, max_length, stride)
    return DataLoader(windowed_dataset, batch_size=batch_size)


# Load the sample text used to build the demo batches
with open("small-text-sample.txt", "r", encoding="utf-8") as f:
    raw_text = f.read()

tokenizer = tiktoken.get_encoding("gpt2")
encoded_text = tokenizer.encode(raw_text)

vocab_size = 50257  # number of rows in the token-embedding table
output_dim = 256    # embedding width
max_len = 1024      # longest sequence the positional table can address
block_size = max_len


token_embedding_layer = torch.nn.Embedding(vocab_size, output_dim)
# Positional embeddings are looked up by position index (0 .. block_size - 1),
# not by token id, so the table needs block_size rows rather than vocab_size.
pos_embedding_layer = torch.nn.Embedding(block_size, output_dim)

# The demo below uses very short windows of 4 tokens
max_length = 4
dataloader = create_dataloader(raw_text, batch_size=8, max_length=max_length, stride=5)

In [2]:
# Only the first batch is embedded; it serves as sample input for the
# attention modules further down.
for batch in dataloader:
    x, y = batch

    # One positional vector per position 0 .. max_length - 1
    position_ids = torch.arange(max_length)
    pos_embeddings = pos_embedding_layer(position_ids)
    token_embeddings = token_embedding_layer(x)

    # Broadcasting adds the same positional vectors to every sequence in the batch
    input_embeddings = token_embeddings + pos_embeddings

    break

In [3]:
# Shape is (batch_size, max_length, output_dim)
print(input_embeddings.shape)

torch.Size([8, 4, 256])


# Multi-head Attention from Chapter 3

## Variant A: Simple implementation

In [4]:
import torch

class CausalSelfAttention(torch.nn.Module):
    """Single-head causal (masked) scaled dot-product self-attention.

    Args:
        d_in: size of each input embedding.
        d_out: size of the query/key/value projections and of each output vector.
        block_size: side length of the precomputed causal mask. ``forward``
            slices this mask to the sequence length, so sequences longer
            than ``block_size`` are not supported.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, d_in, d_out, block_size, dropout):
        super().__init__()
        self.d_out = d_out
        self.W_query = torch.nn.Linear(d_in, d_out, bias=False)
        self.W_key   = torch.nn.Linear(d_in, d_out, bias=False)
        self.W_value = torch.nn.Linear(d_in, d_out, bias=False)
        self.dropout = torch.nn.Dropout(dropout)
        # Ones strictly above the diagonal mark the "future" positions that
        # each query must not attend to.
        self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1))

    def forward(self, x):
        """Return context vectors of shape (b, n_tokens, d_out).

        ``x`` must have shape (b, n_tokens, d_in).
        """
        _, n_tokens, _ = x.shape
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        # (b, n_tokens, n_tokens): row i holds the scores of query i against every key
        attn_scores = queries @ keys.transpose(1, 2)
        attn_scores.masked_fill_(  # in-place: future positions get -inf
            self.mask.bool()[:n_tokens, :n_tokens], -torch.inf)
        # Normalize over the key axis (last dim) so every query's weights sum
        # to 1. Normalizing over dim=1 would mix in scores from later queries
        # and break causality.
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        context_vec = attn_weights @ values
        return context_vec


class MultiHeadAttentionWrapper(torch.nn.Module):
    """Multi-head attention built from independent CausalSelfAttention heads.

    Each of the ``num_heads`` heads maps ``d_in`` features to ``d_out``
    features. The head outputs are concatenated along the last dimension
    and passed through a square linear projection of width
    ``d_out * num_heads``.
    """

    def __init__(self, d_in, d_out, block_size, dropout, num_heads):
        super().__init__()
        # Heads are created before the projection, so parameters are
        # initialised in that order.
        attention_heads = [
            CausalSelfAttention(d_in, d_out, block_size, dropout)
            for _ in range(num_heads)
        ]
        self.heads = torch.nn.ModuleList(attention_heads)

        concat_width = d_out*num_heads
        self.out_proj = torch.nn.Linear(concat_width, concat_width)

    def forward(self, x):
        # Run every head on the same input, then join along the feature axis
        per_head = [head(x) for head in self.heads]
        combined = torch.cat(per_head, dim=-1)
        return self.out_proj(combined)

In [5]:
# Seed before building the module so its random initial weights are reproducible
torch.manual_seed(123)

num_heads = 2
d_in = output_dim
block_size = max_length
# Give each head an equal share of the input width
d_out = d_in // num_heads

mha = MultiHeadAttentionWrapper(
    d_in=d_in,
    d_out=d_out,
    block_size=block_size,
    dropout=0.0,
    num_heads=num_heads,
)

batch = input_embeddings
context_vecs = mha(batch)

print("context_vecs.shape:", context_vecs.shape)

context_vecs.shape: torch.Size([8, 4, 256])


## Variant B: Alternative implementation

In [6]:
import torch


class MultiHeadAttention(torch.nn.Module):
    """Causal multi-head attention using one fused projection per Q/K/V.

    Queries, keys and values are each projected to ``d_out`` features once
    and then split into ``num_heads`` heads of ``d_out // num_heads``
    features, so all heads are processed in one batched matrix product.

    Args:
        d_in: size of each input embedding.
        d_out: total output width; must be divisible by ``num_heads``.
        block_size: side length of the precomputed causal mask.
        dropout: dropout probability applied to the attention weights.
        num_heads: number of attention heads.
    """

    def __init__(self, d_in, d_out, block_size, dropout, num_heads):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by n_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        # Width of a single head; the heads together span d_out features
        self.head_dim = d_out // num_heads

        self.W_query = torch.nn.Linear(d_in, d_out, bias=False)
        self.W_key = torch.nn.Linear(d_in, d_out, bias=False)
        self.W_value = torch.nn.Linear(d_in, d_out, bias=False)
        # Mixes the concatenated head outputs
        self.out_proj = torch.nn.Linear(d_out, d_out)
        self.dropout = torch.nn.Dropout(dropout)

        # Ones strictly above the diagonal mark positions a query may not see
        causal_mask = torch.triu(torch.ones(block_size, block_size), diagonal=1)
        self.register_buffer('mask', causal_mask)

    def _split_heads(self, projected):
        # (b, n_tokens, d_out) -> (b, num_heads, n_tokens, head_dim)
        b, n_tokens, _ = projected.shape
        per_head = projected.view(b, n_tokens, self.num_heads, self.head_dim)
        return per_head.transpose(1, 2)

    def forward(self, x):
        b, n_tokens, _ = x.shape

        keys = self._split_heads(self.W_key(x))
        queries = self._split_heads(self.W_query(x))
        values = self._split_heads(self.W_value(x))

        # Per-head scores, shape (b, num_heads, n_tokens, n_tokens)
        attn_scores = queries @ keys.transpose(2, 3)

        # Broadcast the (n_tokens, n_tokens) mask over batch and head axes
        future = self.mask.bool()[:n_tokens, :n_tokens]
        attn_scores.masked_fill_(future[None, None], -torch.inf)

        scale = keys.shape[-1]**0.5
        attn_weights = torch.softmax(attn_scores / scale, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Back to (b, n_tokens, num_heads, head_dim), then flatten the heads
        per_head_context = (attn_weights @ values).transpose(1, 2)
        merged = per_head_context.contiguous().view(b, n_tokens, self.d_out)

        return self.out_proj(merged)

In [7]:
# Seed before building the module so its random initial weights are reproducible
torch.manual_seed(123)

d_in = output_dim
# The heads split d_out between them, so the output keeps the input width
d_out = d_in
block_size = max_length

mha = MultiHeadAttention(
    d_in=d_in,
    d_out=d_out,
    block_size=block_size,
    dropout=0.0,
    num_heads=2,
)

batch = input_embeddings
context_vecs = mha(batch)

print("context_vecs.shape:", context_vecs.shape)

context_vecs.shape: torch.Size([8, 4, 256])
