###### Multi-head Attention Plus Data Loading

In [None]:
# NBVAL_IGNORE_OUTPUT

from importlib.metadata import version

# Report the installed PyTorch version, read from package metadata.
torch_version = version("torch")
print("torch version:", torch_version)

###### The complete chapter code is located in ch03.ipynb.

###### This notebook contains the chapter's main takeaway, the multi-head attention implementation (plus the data loading pipeline from chapter 2).

#### Data Loader from Chapter 2

In [None]:
import tiktoken
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

class GPTDatasetV1(Dataset):
    """Sliding-window next-token dataset built from a single text.

    The whole text is tokenized once. Windows of ``max_length`` tokens are then
    taken every ``stride`` tokens. Each sample is an ``(input, target)`` pair of
    tensors in which ``target`` is ``input`` shifted one token to the right.
    """

    def __init__(self, txt, tokenizer, max_length, stride):
        ids = tokenizer.encode(txt, allowed_special={'<|endoftext|>'})

        # Window starts advance by `stride`. Stopping before
        # len(ids) - max_length guarantees that the target slice, which
        # reaches one token past its input, is still max_length long.
        starts = range(0, len(ids) - max_length, stride)
        self.input_ids = [torch.tensor(ids[s:s + max_length]) for s in starts]
        self.target_ids = [torch.tensor(ids[s + 1:s + 1 + max_length]) for s in starts]

    def __len__(self):
        """Return the number of windows extracted from the text."""
        return len(self.input_ids)

    def __getitem__(self, idx):
        """Return the ``(input, target)`` tensor pair for window ``idx``."""
        inputs = self.input_ids[idx]
        targets = self.target_ids[idx]
        return inputs, targets

def create_dataloader(txt, batch_size=4, max_length=256, stride=128, shuffle=True):
    """Wrap ``txt`` in a ``DataLoader`` of GPT-2-tokenized sliding windows.

    Args:
        txt: Raw text to tokenize.
        batch_size: Number of windows per batch.
        max_length: Number of tokens per window.
        stride: Offset between the starts of consecutive windows.
        shuffle: Passed straight through to ``DataLoader``.

    Returns:
        A ``DataLoader`` over a ``GPTDatasetV1`` built from ``txt``.
    """
    gpt2_tokenizer = tiktoken.get_encoding("gpt2")
    return DataLoader(
        GPTDatasetV1(txt, gpt2_tokenizer, max_length, stride),
        batch_size=batch_size,
        shuffle=shuffle,
    )

# Load the sample corpus and keep its GPT-2 token ids for inspection.
with open("small-text-sample.txt", "r", encoding="utf-8") as f:
    raw_text = f.read()

tokenizer = tiktoken.get_encoding("gpt2")
encoded_text = tokenizer.encode(raw_text)

vocab_size = 50257  # GPT-2 BPE vocabulary size
output_dim = 256    # embedding dimension
max_len = 1024
context_length = max_len  # number of positions the positional embedding covers

token_embedding_layer = nn.Embedding(vocab_size, output_dim)
pos_embedding_layer = nn.Embedding(context_length, output_dim)

max_length = 4
# stride == max_length gives non-overlapping windows. batch_size must be a
# positive integer: DataLoader rejects batch_size=0 with a ValueError, so
# the previous value could never yield a batch.
dataloader = create_dataloader(raw_text, batch_size=8, max_length=max_length, stride=max_length)

In [None]:
# Embed only the first batch: look up token embeddings, then add one
# positional embedding per position 0..max_length-1, broadcast over the batch.
for batch in dataloader:
    x, y = batch
    token_embeddings = token_embedding_layer(x)
    positions = torch.arange(max_length)
    pos_embeddings = pos_embedding_layer(positions)
    input_embeddings = token_embeddings + pos_embeddings
    break  # a single batch is enough for the shape check below

In [None]:
# Should be (batch, max_length, output_dim).
embedding_shape = input_embeddings.shape
print(embedding_shape)

#### Multi-head Attention from Chapter 3
##### Variant A: Simple implementation

In [None]:
class CausalAttention(nn.Module):
    """Single-head causal (masked) self-attention with dropout on the weights.

    Args:
        d_in: Size of the last dimension of the input ``x``.
        d_out: Size of the query/key/value projections and of the output.
        context_length: Maximum number of tokens supported; sizes the causal mask.
        dropout: Dropout probability applied to the attention weights.
        qkv_bias: Whether the query/key/value projections use a bias term.
    """

    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        # Ones strictly above the diagonal mark "future" positions a token must
        # not attend to. Registered as a buffer so it moves with the module
        # (e.g. .to(device)) but is not a trainable parameter.
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        """Return context vectors of shape ``(b, n_tokens, d_out)``.

        ``x`` must have shape ``(b, n_tokens, d_in)`` with
        ``n_tokens <= context_length``, the size of the registered mask.
        """
        _, n_tokens, _ = x.shape  # enforces a rank-3 (batch, tokens, d_in) input
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        # (b, n, d_out) @ (b, d_out, n) -> (b, n, n); transpose only the token
        # and feature axes so the batch dimension is preserved.
        attn_scores = queries @ keys.transpose(1, 2)
        # Slice the mask to the actual sequence length, which may be shorter
        # than context_length; in-place fill is safe on this intermediate.
        attn_scores.masked_fill_(self.mask.bool()[:n_tokens, :n_tokens], -torch.inf)
        # Scale by sqrt(key dim) so softmax does not saturate as d_out grows;
        # normalize over the last axis (the attended-to tokens).
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        context_vec = attn_weights @ values
        return context_vec
