# Multi-head attention mechanism combined with data loading

The complete chapter code is in [ch03.ipynb](./ch03.ipynb).

This notebook collects the chapter's main takeaway, the implementation of multi-head attention, together with the data-loading pipeline from Chapter 2.

## Data loader from Chapter 2

In [1]:
# Import necessary libraries
import tiktoken  # OpenAI's BPE tokenizer library (provides the "gpt2" encoding)
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# Define the GPT dataset class
class GPTDatasetV1(Dataset):
    """Next-token-prediction dataset cut from one text with a sliding window.

    The whole text is tokenized once. Every window of ``max_length`` token IDs
    becomes an input, and the same window shifted one token to the right
    becomes its target. Consecutive windows start ``stride`` tokens apart.

    Args:
        txt: the raw text to tokenize.
        tokenizer: object exposing ``encode(text, allowed_special=...)``
            (e.g. a tiktoken encoding).
        max_length: number of tokens per input/target window.
        stride: offset between the start positions of consecutive windows.
    """

    def __init__(self, txt, tokenizer, max_length, stride):
        self.tokenizer = tokenizer
        self.input_ids = []
        self.target_ids = []

        ids = tokenizer.encode(txt, allowed_special={'<|endoftext|>'})

        # The last start must leave room for the shifted target window,
        # so the exclusive upper bound is len(ids) - max_length.
        last_start = len(ids) - max_length
        for start in range(0, last_start, stride):
            stop = start + max_length
            self.input_ids.append(torch.tensor(ids[start:stop]))
            self.target_ids.append(torch.tensor(ids[start + 1:stop + 1]))

    def __len__(self):
        """Number of (input, target) window pairs."""
        return len(self.input_ids)

    def __getitem__(self, idx):
        """Return the ``(input_ids, target_ids)`` tensor pair at ``idx``."""
        return self.input_ids[idx], self.target_ids[idx]

# Create a data loader function
def create_dataloader(txt, batch_size=4, max_length=256, stride=128, shuffle=True):
    """Build a DataLoader of (input, target) token-ID windows over ``txt``.

    The text is tokenized with tiktoken's "gpt2" encoding and windowed by
    ``GPTDatasetV1``.

    Args:
        txt: the raw text.
        batch_size: number of windows per batch.
        max_length: tokens per window.
        stride: offset between the starts of consecutive windows.
        shuffle: whether the DataLoader shuffles the windows each epoch.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding ``(inputs, targets)`` batches.
    """
    gpt2_tokenizer = tiktoken.get_encoding("gpt2")
    windows = GPTDatasetV1(txt, gpt2_tokenizer, max_length, stride)
    return DataLoader(windows, batch_size=batch_size, shuffle=shuffle)

# Read text file
with open("small-text-sample.txt", "r", encoding="utf-8") as f:
    raw_text = f.read()

# Initialize the tokenizer
tokenizer = tiktoken.get_encoding("gpt2")
# Encode the original text
encoded_text = tokenizer.encode(raw_text)

# Define vocabulary size, output dimension, maximum length and block size
vocab_size = 50257
output_dim = 256
max_len = 1024
block_size = max_len

# Create word embedding layer and position embedding layer
token_embedding_layer = nn.Embedding(vocab_size, output_dim)
pos_embedding_layer = torch.nn.Embedding(block_size, output_dim)

# Setting the maximum length to 4 is probably a mistake, since usually the maximum length will be greater than 4
max_length = 4
# Create a data loader
dataloader = create_dataloader(raw_text, batch_size=8, max_length=max_length, stride=5)

In [2]:
# Fetch only the first (shuffled) batch for this demonstration.
# next(iter(...)) replaces a `for ... break` loop; on an empty loader it raises
# StopIteration immediately instead of leaving input_embeddings undefined.
batch = next(iter(dataloader))
# Unpack the input token IDs and their next-token targets
x, y = batch

# Token embeddings for the input IDs: (batch, seq_len, output_dim)
token_embeddings = token_embedding_layer(x)
# Position IDs 0..seq_len-1 are taken from the batch itself rather than the
# global max_length, so they always match the actual sequence length.
pos_embeddings = pos_embedding_layer(torch.arange(x.shape[1]))

# Position embeddings broadcast over the batch axis to form the model input
input_embeddings = token_embeddings + pos_embeddings

In [3]:
print(input_embeddings.size())  # (batch_size, max_length, output_dim)

torch.Size([8, 4, 256])


# Multi-head attention from Chapter 3

## Method A: Simple implementation

In [4]:
# Define causal self-attention module
class CausalSelfAttention(nn.Module):
    """Single-head causal (masked) scaled dot-product self-attention.

    Args:
        d_in: feature size of the input embeddings.
        d_out: feature size of the query/key/value projections and the output.
        block_size: maximum sequence length covered by the causal mask.
        dropout: dropout probability applied to the attention weights.
        qkv_bias: whether the query/key/value projections use a bias term.
    """

    def __init__(self, d_in, d_out, block_size, dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        # Query/key/value projections
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        # Regularizes the attention weights
        self.dropout = nn.Dropout(dropout)
        # 1 above the diagonal (key index > query index) marks future positions to hide
        self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1))

    def forward(self, x):
        """Attend over ``x`` of shape (batch, n_tokens, d_in); returns (batch, n_tokens, d_out)."""
        b, n_tokens, _ = x.shape
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        # (b, n_tokens, n_tokens): row = query position, column = key position
        attn_scores = queries @ keys.transpose(1, 2)
        # Future positions get -inf so they receive zero weight after softmax
        attn_scores.masked_fill_(self.mask.bool()[:n_tokens, :n_tokens], -torch.inf)
        # Normalize over the KEY axis (last dim) so each query's weights sum to 1.
        # Normalizing over dim=1 (queries) would mix in scores of future positions
        # and break causality.
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Weighted sum of value vectors per query position
        context_vec = attn_weights @ values
        return context_vec

# Define the multi-head attention wrapper module
class MultiHeadAttentionWrapper(nn.Module):
    """Multi-head attention built by stacking independent causal attention heads.

    Each of the ``num_heads`` ``CausalSelfAttention`` heads yields ``d_out``
    features. Their outputs are concatenated on the last axis
    (``num_heads * d_out`` features) and mixed by a square linear projection.

    Args:
        d_in: feature size of the input.
        d_out: output feature size of EACH head.
        block_size: maximum sequence length covered by the causal mask.
        dropout: dropout probability inside every head.
        num_heads: number of independent heads.
        qkv_bias: whether the heads' query/key/value projections use a bias.
    """

    def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):
        super().__init__()
        head_list = [
            CausalSelfAttention(d_in, d_out, block_size, dropout, qkv_bias)
            for _ in range(num_heads)
        ]
        self.heads = nn.ModuleList(head_list)
        concat_dim = num_heads * d_out
        self.out_proj = nn.Linear(concat_dim, concat_dim)

    def forward(self, x):
        """Run every head on ``x``, concatenate the results, and project them."""
        per_head = [head(x) for head in self.heads]
        return self.out_proj(torch.cat(per_head, dim=-1))

In [5]:
# Seed the RNG so the randomly initialized weights (and the result) are reproducible;
# this must run before the module below is constructed
torch.manual_seed(123)

# Causal-mask size equals the context length used by the data loader
# (this overwrites the earlier block_size = max_len)
block_size = max_length
# Input feature size = embedding width
d_in = output_dim

# Number of attention heads
num_heads = 2
# Per-head output size; the wrapper concatenates the heads, so the total
# output width is num_heads * d_out (here equal to d_in)
d_out = d_in // num_heads

# Build the wrapper; dropout 0.0 disables dropout for this demo
mha = MultiHeadAttentionWrapper(d_in, d_out, block_size, 0.0, num_heads)

# Reuse the input embeddings computed in the earlier cell
batch = input_embeddings
# Apply multi-head attention to the batch
context_vecs = mha(batch)

# Shape is (batch_size, num_tokens, num_heads * d_out)
print("context_vecs.shape:", context_vecs.shape)

context_vecs.shape: torch.Size([8, 4, 256])


## Method B: Alternative Implementation

In [6]:
# Define multi-head self-attention module
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with shared Q/K/V projections.

    One ``d_in -> d_out`` projection each for queries, keys, and values is
    split into ``num_heads`` slices of ``head_dim = d_out // num_heads``
    features. Attention runs for all heads in parallel, and the heads are
    merged back to ``d_out`` features before a final linear projection.

    Args:
        d_in: feature size of the input.
        d_out: total output feature size; must be divisible by ``num_heads``.
        block_size: maximum sequence length covered by the causal mask.
        dropout: dropout probability applied to the attention weights.
        num_heads: number of attention heads.
        qkv_bias: whether the Q/K/V projections include a bias term.
    """

    def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by n_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        # Feature size handled by each head
        self.head_dim = d_out // num_heads
        # Query, key, and value projections (creation order fixes seeded init)
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        # Mixes the merged head outputs
        self.out_proj = nn.Linear(d_out, d_out)
        # Regularizes the attention weights
        self.dropout = nn.Dropout(dropout)
        # 1 above the diagonal (key index > query index) marks future positions to hide
        self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1))

    def forward(self, x):
        """Attend over ``x`` of shape (batch, seq_len, d_in); returns (batch, seq_len, d_out)."""
        batch, seq_len, _ = x.shape

        # Project, then reshape each to (batch, num_heads, seq_len, head_dim)
        queries, keys, values = (
            proj(x).view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
            for proj in (self.W_query, self.W_key, self.W_value)
        )

        # Per-head scores: (batch, num_heads, seq_len, seq_len)
        scores = queries @ keys.transpose(2, 3)
        # Hide future positions; two leading axes let the mask broadcast over batch and heads
        future = self.mask.bool()[:seq_len, :seq_len]
        scores.masked_fill_(future[None, None], -torch.inf)

        # Scaled softmax over the key axis, followed by dropout on the weights
        weights = self.dropout(torch.softmax(scores / keys.shape[-1] ** 0.5, dim=-1))

        # Back to (batch, seq_len, num_heads, head_dim), then merge heads into d_out features
        merged = (weights @ values).transpose(1, 2).contiguous().view(batch, seq_len, self.d_out)
        return self.out_proj(merged)

In [7]:
# Seed the RNG so the randomly initialized weights (and the result) are reproducible;
# this must run before the module below is constructed
torch.manual_seed(123)

# Causal-mask size equals the context length used by the data loader
block_size = max_length
# Input feature size = embedding width
d_in = output_dim
# Total output width (all heads together) equals the input width; with
# num_heads=2 below, each head handles d_out // 2 features
d_out = d_in

# Build the batched multi-head module; dropout 0.0 disables dropout for this demo
mha = MultiHeadAttention(d_in, d_out, block_size, dropout=0.0, num_heads=2)

# Reuse the input embeddings computed in the earlier cell
batch = input_embeddings
# Apply multi-head attention to the batch
context_vecs = mha(batch)

# Shape is (batch_size, num_tokens, d_out)
print("context_vecs.shape:", context_vecs.shape)

context_vecs.shape: torch.Size([8, 4, 256])
