# Creating Tokens from Text

In [1]:
with open("../ch02/01_main-chapter-code/the-verdict.txt") as f:
    raw_text = f.read()
print("Total number of characters in the text file: ", len(raw_text))
print(raw_text[:99])

Total number of characters in the text file:  20479
I HAD always thought Jack Gisburn rather a cheap genius--though a good fellow enough--so it was no 


## Tokenization

### Getting Vocabulary

In [2]:
import re
result = re.split(r'([,.:;?_!"()\']|--|\s)', raw_text)
result = [item for item in result if item.strip()] # Remove space strings
print(result[:30])
print(len(result))

['I', 'HAD', 'always', 'thought', 'Jack', 'Gisburn', 'rather', 'a', 'cheap', 'genius', '--', 'though', 'a', 'good', 'fellow', 'enough', '--', 'so', 'it', 'was', 'no', 'great', 'surprise', 'to', 'me', 'to', 'hear', 'that', ',', 'in']
4690


In [3]:
all_words = sorted(set(result))  # Get unique words and sort them
vocab_size = len(all_words)
print("Total number of unique words in the text file: ", vocab_size)

Total number of unique words in the text file:  1130


In [4]:
vocab = {token: integer for integer, token in enumerate(all_words)}  # Create a dictionary of tokens and their IDs
for i, item in enumerate(list(vocab.items())[-5:]):
    print(item)
    if i > 10:
        break

('yet', 1125)
('you', 1126)
('younger', 1127)
('your', 1128)
('yourself', 1129)


### Implementing a Simple Text Tokenizer

In [5]:
class SimpleTokenizerV1:
    def __init__(self, vocabulary): # We must have a vocabulary to initialize the tokenizer
        self.str_to_int = vocabulary  # Dictionary of tokens and their IDs {"token1": 0, ...}
        self.int_to_str = {v: k for k, v in vocabulary.items()}  # Dictionary of IDs and their tokens {0: "token1", ...}

    def encode(self, text):
        # Convert text to a list of token IDs
        preprocessed_text = re.split(r'([,.:;?_!"()\']|--|\s)', text)
        preprocessed_text = [item.strip() for item in preprocessed_text if item.strip()]
        ids = [self.str_to_int[token] for token in preprocessed_text]
        return ids  # return a list of token IDs

    def decode(self, ids):
        # Convert a list of token IDs to text with spaces
        text = " ".join([self.int_to_str[id] for id in ids])

        text = re.sub(r'\s+([,.?!"()\'])', r'\1', text)  # Remove spaces before punctuation
        return text

In [6]:
tokenizer = SimpleTokenizerV1(vocab)
text = """"It's the last he painted, you know," Mrs. Gisburn said with pardonable pride."""
ids = tokenizer.encode(text)
print(ids)
print(tokenizer.decode(ids))

[1, 56, 2, 850, 988, 602, 533, 746, 5, 1126, 596, 5, 1, 67, 7, 38, 851, 1108, 754, 793, 7]
" It' s the last he painted, you know," Mrs. Gisburn said with pardonable pride.


### Implementing a Simple Text Tokenizer that Handles Unknown Tokens

In [7]:
all_words.extend(["<|endoftext|>", "<|unk|>"])  # Add special tokens for end of text and unknown words
vocab = {token:integer for integer,token in enumerate(all_words)}  # Create a new vocabulary

for i, item in enumerate(list(vocab.items())[-5:]):
    print(item)

('younger', 1127)
('your', 1128)
('yourself', 1129)
('<|endoftext|>', 1130)
('<|unk|>', 1131)


In [8]:
class SimpleTokenizerV2:
    def __init__(self, vocabulary): # We must have a vocabulary to initialize the tokenizer
        self.str_to_int = vocabulary  # Dictionary of tokens and their IDs {"token1": 0, ...}
        self.int_to_str = {v: k for k, v in vocabulary.items()}  # Dictionary of IDs and their tokens {0: "token1", ...}

    def encode(self, text) -> list:
        preprocessed_tokens = re.split(r'([,.:;?_!"()\']|--|\s)', text)
        preprocessed_tokens = [item.strip() for item in preprocessed_tokens if item.strip()]  # remove empty space strings and empty strings
        preprocessed_tokens = [item if item in self.str_to_int else "<|unk|>" for item in preprocessed_tokens]  # Replace unknown words from source text with <|unk|>

        ids = [self.str_to_int[token] for token in preprocessed_tokens]
        return ids

    def decode(self, ids) -> str:
        decoded_text = " ".join([self.int_to_str[i] for i in ids])
        decoded_text = re.sub(r'\s+([,.?!"()\'])', r'\1', decoded_text) # Remove spaces before punctuation
        return decoded_text

In [9]:
text1 = "Hello, do you like tea?"
text2 = "In the sunlit terraces of the palace."
text = " <|endoftext|> ".join((text1, text2))
print(text)

tokenizer = SimpleTokenizerV2(vocab)
print(tokenizer.encode(text))
print(tokenizer.decode(tokenizer.encode(text)))

Hello, do you like tea? <|endoftext|> In the sunlit terraces of the palace.
[1131, 5, 355, 1126, 628, 975, 10, 1130, 55, 988, 956, 984, 722, 988, 1131, 7]
<|unk|>, do you like tea? <|endoftext|> In the sunlit terraces of the <|unk|>.


### Byte Pair Encoding (BPE): A Tokenzier Which can Handle Any Unknown Words

In [10]:
# Use tiktoken library to implement BPE
import tiktoken
from importlib.metadata import version
print(version("tiktoken"))

0.5.2


In [11]:
tokenizer = tiktoken.get_encoding("gpt2")

In [12]:
text = "Hello, do you like tea? <|endoftext|> In the sunlit terraces of someunknownPlace. Akwirw ier."
ids = tokenizer.encode(text, allowed_special={"<|endoftext|>"})
print(ids)

text = tokenizer.decode(ids)
print(text)

[15496, 11, 466, 345, 588, 8887, 30, 220, 50256, 554, 262, 4252, 18250, 8812, 2114, 286, 617, 34680, 27271, 13, 9084, 86, 343, 86, 220, 959, 13]
Hello, do you like tea? <|endoftext|> In the sunlit terraces of someunknownPlace. Akwirw ier.


# Creating Input and Targets Data Loader Using PyTorch

## Data Sampling with A Sliding Window

In [13]:
# Encode the text using the BPE tokenizer
enc_text = tokenizer.encode(raw_text)
print(len(enc_text))

5145


In [14]:
# Take a sample of encoded text
enc_sample = enc_text[50:]

# Create Input/Target Pairs by using a sliding window
context_size = 4
x = enc_sample[:context_size] # Input is the first 4 tokens of sampled text [0:4]
y = enc_sample[1:context_size+1] # Target is the next 4 tokens of sampled text [1:5] (Right shift by 1)
print(f"x: {x}")
print(f"y:      {y}\n")

# context and desired are encoded token ids (integers)
for i in range(1, context_size+1):
    context = enc_sample[:i]
    desired = enc_sample[i]
    print(context, "---->", desired)

print('\n')

# Do the above loop again with the decoded token ids
for i in range(1, context_size+1):
    context = enc_sample[:i]  # list
    desired = enc_sample[i]  # integer, so need to convert to list
    print(tokenizer.decode(context), "---->", tokenizer.decode([desired]))

x: [290, 4920, 2241, 287]
y:      [4920, 2241, 287, 257]

[290] ----> 4920
[290, 4920] ----> 2241
[290, 4920, 2241] ----> 287
[290, 4920, 2241, 287] ----> 257


 and ---->  established
 and established ---->  himself
 and established himself ---->  in
 and established himself in ---->  a


## Using a PyTorch Data Loader To Create Input and Target Tensors

In [15]:
import torch
from torch.utils.data import Dataset, DataLoader

from torch import Tensor

In [16]:
# Define the Dataset class
class GPTDataSetV1(Dataset):
    def __init__(self, txt, tokenizer, max_length, stride):
        """
        :param txt: source text
        :param tokenizer: tokenizer object
        :param max_length: number of token ids in a chunk
        :param stride: shift across batches. For example, input/target pair batch 1 is [0:4]/[1:5], batch 2 is [1:5]/[2:6], etc.
        """
        self.input_ids = []  # A list of tensors, or a multi-dimension array
        self.target_ids = [] # A list of tensors, or a multi-dimension array

        token_ids = tokenizer.encode(txt)  # Tokenize the text and create token ids

        # Create input/target pairs using a sliding window
        for i in range(0, len(token_ids) - max_length, stride):
            input_chunk = token_ids[i:i+max_length]
            target_chunk = token_ids[i+1:i+max_length+1]
            self.input_ids.append(torch.tensor(input_chunk))
            self.target_ids.append(torch.tensor(target_chunk))

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx) -> (Tensor, Tensor):
        # Returns a pair of 1D tensors (input_ids, target_ids) at the index idx
        return self.input_ids[idx], self.target_ids[idx]


In [17]:
def create_dataloader_v1(txt, batch_size=4, max_length=256, stride=128, shuffle=True, drop_last=True, num_workers=0):
    tokenizer = tiktoken.get_encoding("gpt2")
    dataset = GPTDataSetV1(txt, tokenizer, max_length, stride)

    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=num_workers)

In [18]:
# batch_size: number of batches in input tensors or target tensors
# max_length: number of token ids in a batch
# stride: shift across batches in input tensors
data_loader = create_dataloader_v1(raw_text, batch_size=2, max_length=4, stride=4, shuffle=False)

data_iter = iter(data_loader) # Iteration of the input/target tensor pairs
print(next(data_iter))

[tensor([[  40,  367, 2885, 1464],
        [1807, 3619,  402,  271]]), tensor([[  367,  2885,  1464,  1807],
        [ 3619,   402,   271, 10899]])]


# Creating Embeddings

## Input Token Embeddings

In [19]:
torch.manual_seed(123)

# Creating a Dataloader
max_length = 4
dataloader = create_dataloader_v1(
    raw_text, 
    batch_size=8, 
    max_length=max_length, 
    stride=max_length, 
    shuffle=False)
data_iter = iter(dataloader)

inputs, targets = next(data_iter)

print("Input Token IDs (per batch): \n", inputs)
print("\nInputs shape:\n", inputs.shape)

# Create an embedding layer
vocab_size = 50257
output_dim = 256

token_embedding_layer = torch.nn.Embedding(vocab_size, output_dim)
print("\nEmbedding Layer:\n", token_embedding_layer)
input_embeddings = token_embedding_layer(inputs)
print("\nInput Embeddings Shape:\n", input_embeddings.shape)

Input Token IDs (per batch): 
 tensor([[   40,   367,  2885,  1464],
        [ 1807,  3619,   402,   271],
        [10899,  2138,   257,  7026],
        [15632,   438,  2016,   257],
        [  922,  5891,  1576,   438],
        [  568,   340,   373,   645],
        [ 1049,  5975,   284,   502],
        [  284,  3285,   326,    11]])

Inputs shape:
 torch.Size([8, 4])

Embedding Layer:
 Embedding(50257, 256)

Input Embeddings Shape:
 torch.Size([8, 4, 256])


## Input Positional Embeddings (Context Embeddings)

In [20]:
context_length = max_length # Positional context has the same number of input tokens in a batch
pos_embedding_layer = torch.nn.Embedding(context_length, output_dim)  # We only need the same number of positional embeddings as the number of tokens in a batch

pos_embeddings = pos_embedding_layer(torch.arange(context_length))  # torch.arrange creates a tensor with values from 0 to context_length-1
print(pos_embeddings.shape)

torch.Size([4, 256])


## Input Embeddings: Combine Input Token Embeddings and Positional Embeddings

In [21]:
input_embeddings = input_embeddings + pos_embeddings
print(input_embeddings.shape)

torch.Size([8, 4, 256])


# Self-Attention Mechanism

In [22]:
import torch

# Say, we have a text input sequence of 6 word and each word/token has 3 dimension embedding.
# So, the input sequence is of shape (6, 3)
inputs = torch.tensor(
    [
        [0.43, 0.15, 0.89], # word 1, Your (x^1)
        [0.55, 0.87, 0.66], # word 2, journey (x^2)
        [0.57, 0.85, 0.64], # word 3, starts (x^3)
        [0.22, 0.58, 0.33], # word 4, with (x^4)
        [0.77, 0.25, 0.10], # word 5, one (x^5)
        [0.05, 0.80, 0.55], # word 6, step (x^6)
    ]
)

## Simplified Self-Attention

### Step 1: Attention Scores

In [23]:
# Create a matrix of Attention Scores
attn_scores = torch.empty(inputs.shape[0], inputs.shape[0])

for i, x_i in enumerate(inputs):
    for j, x_j in enumerate(inputs):
        attn_scores[i, j] = torch.dot(x_i, x_j)  # dot product of each word/token with each other

print("Attention scores matrix:\n", attn_scores)

# Or, Calculate the Attention Score matrix by using matrix multiplication
attn_scores = torch.mm(inputs, torch.t(inputs)) # Input matrix x Input matrix transpose

print("Attention scores matrix:\n", attn_scores)

# Or, a simplified presentation of matrix multiplication
attn_scores = inputs @ inputs.T

print("Attention scores matrix:\n", attn_scores)

Attention scores matrix:
 tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])
Attention scores matrix:
 tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])
Attention scores matrix:
 tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0

### Step 2: Attention Weights

In [24]:
attn_weights = torch.softmax(attn_scores, dim=1)
print("Attention weights:\n", attn_weights)

Attention weights:
 tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])


### Step 3: Context Embeddings

In [25]:

## Calculate the Context Embedding
context_embeddings = attn_weights @ inputs
print("Context embeddings:\n", context_embeddings)

Context embeddings:
 tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])


## Scaled Dot-Product Attention

### Step 1: Query, Key, and Value Matrices

In [26]:
torch.manual_seed(123)
d_in = inputs.shape[1]   # This is dimension of the input token, M in the NxM
d_out = inputs.shape[1]

W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

In [27]:
query = inputs @ W_query
keys = inputs @ W_key
values = inputs @ W_value
print("Query:", query)
print("Key:", keys)
print("Value:", values)

Query: tensor([[0.3522, 0.3244, 0.4020],
        [0.8520, 0.4161, 1.0138],
        [0.8415, 0.4229, 0.9978],
        [0.5096, 0.1904, 0.6187],
        [0.4138, 0.4265, 0.4288],
        [0.6408, 0.1414, 0.8070]])
Key: tensor([[0.6813, 0.2706, 1.0793],
        [0.7305, 0.4227, 1.1993],
        [0.7355, 0.4227, 1.1901],
        [0.3363, 0.2225, 0.6077],
        [0.6184, 0.3038, 0.6909],
        [0.3178, 0.2383, 0.7426]])
Value: tensor([[0.4976, 0.9655, 0.7614],
        [0.9074, 1.3518, 1.5075],
        [0.8976, 1.3391, 1.4994],
        [0.5187, 0.7319, 0.8493],
        [0.4699, 0.7336, 0.9307],
        [0.6446, 0.9045, 0.9814]])


### Step 2: Attention Scores

In [28]:
attn_scores = query @ keys.T
print("Attention Score: ", attn_scores)

Attention Score:  tensor([[0.7616, 0.8765, 0.8746, 0.4349, 0.5941, 0.4877],
        [1.7872, 2.0141, 2.0091, 0.9952, 1.3538, 1.1227],
        [1.7646, 1.9901, 1.9852, 0.9834, 1.3383, 1.1091],
        [1.0664, 1.1947, 1.1916, 0.5897, 0.8004, 0.6667],
        [0.8601, 0.9968, 0.9950, 0.4947, 0.6817, 0.5516],
        [1.3458, 1.4957, 1.4915, 0.7374, 0.9968, 0.8366]])


### Step 3: Attention Weights

In [29]:
d_k = keys.shape[-1]
attn_weights = torch.softmax(attn_scores / d_k**0.5, dim=-1)
print("Attention weights:", attn_weights)

Attention weights: tensor([[0.1747, 0.1866, 0.1864, 0.1446, 0.1586, 0.1491],
        [0.1862, 0.2123, 0.2117, 0.1179, 0.1450, 0.1269],
        [0.1859, 0.2118, 0.2112, 0.1184, 0.1454, 0.1273],
        [0.1798, 0.1936, 0.1932, 0.1365, 0.1542, 0.1427],
        [0.1751, 0.1895, 0.1893, 0.1418, 0.1579, 0.1465],
        [0.1837, 0.2003, 0.1998, 0.1293, 0.1501, 0.1369]])


### Step 4: Context Embeddings

In [30]:
context_embeddings = attn_weights @ values
print("Context embedding:", context_embeddings)

Context embedding: tensor([[0.6692, 1.0276, 1.1106],
        [0.6864, 1.0577, 1.1389],
        [0.6860, 1.0570, 1.1383],
        [0.6738, 1.0361, 1.1180],
        [0.6711, 1.0307, 1.1139],
        [0.6783, 1.0441, 1.1252]])


### Implementing a Scaled Dot-Product Attention Class

In [31]:
# Use nn.Parameter to create weight matrices
import torch.nn as nn

class SelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out):
        super(SelfAttention_v1, self).__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
        self.W_key = nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
        self.W_value = nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

    def forward(self, inputs):
        query = inputs @ self.W_query
        key = inputs @ self.W_key
        value = inputs @ self.W_value

        attn_scores = query @ key.T
        attn_weights = torch.softmax(attn_scores / key.shape[-1]**0.5, dim=-1)

        context_embedding = attn_weights @ value
        return context_embedding

In [32]:
# Test the SelfAttention_v1 class
torch.manual_seed(123)

d_in = inputs.shape[1]
d_out = 2

self_attn_v1 = SelfAttention_v1(d_in, d_out)
context_embedding_v1 = self_attn_v1(inputs)
print("Context embeddings:\n", context_embedding_v1)

Context embeddings:
 tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]])


In [33]:
# Use nn.Linear to create weight matrices
import torch.nn as nn

class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

        self.attn_weights = None

    def forward(self, inputs):
        query = self.W_query(inputs)
        key = self.W_key(inputs)
        value = self.W_value(inputs)

        attn_scores = query @ key.T

        self.attn_weights = torch.softmax(attn_scores / key.shape[-1]**0.5, dim=-1)

        context_embedding = attn_weights @ value
        return context_embedding

In [34]:
# Test the SelfAttention_v2 class
torch.manual_seed(789)

d_in = inputs.shape[1]
d_out = 2

self_attn_v2 = SelfAttention_v2(d_in, d_out)
context_embedding_v2 = self_attn_v2(inputs)
print("Context embeddings:\n", context_embedding_v2)

Context embeddings:
 tensor([[-0.0776,  0.0699],
        [-0.0789,  0.0732],
        [-0.0789,  0.0732],
        [-0.0781,  0.0707],
        [-0.0775,  0.0706],
        [-0.0785,  0.0714]], grad_fn=<MmBackward0>)


## Causal Attention

### Step 1: Getting attention scores by using the same steps as the scaled dot-product attention

In [35]:
self_attn_v2 = SelfAttention_v2(d_in, d_out)

queries = self_attn_v2.W_query(inputs)
keys = self_attn_v2.W_key(inputs)

attn_scores = queries @ keys.T

print(attn_scores)

tensor([[0.2118, 0.1588, 0.1574, 0.0699, 0.0885, 0.0899],
        [0.2676, 0.2249, 0.2226, 0.1051, 0.1195, 0.1361],
        [0.2622, 0.2215, 0.2193, 0.1038, 0.1175, 0.1344],
        [0.1496, 0.1257, 0.1244, 0.0587, 0.0668, 0.0760],
        [0.0926, 0.0984, 0.0972, 0.0506, 0.0479, 0.0662],
        [0.2108, 0.1664, 0.1649, 0.0754, 0.0907, 0.0973]],
       grad_fn=<MmBackward0>)


### Step 2: Masking Attention Score

In [36]:
length = attn_weights.shape[0]

mask = torch.triu(torch.ones(length, length), diagonal=1)
masked_attn_score = attn_scores.masked_fill(mask.bool(), -torch.inf)
print(masked_attn_score)

tensor([[0.2118,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.2676, 0.2249,   -inf,   -inf,   -inf,   -inf],
        [0.2622, 0.2215, 0.2193,   -inf,   -inf,   -inf],
        [0.1496, 0.1257, 0.1244, 0.0587,   -inf,   -inf],
        [0.0926, 0.0984, 0.0972, 0.0506, 0.0479,   -inf],
        [0.2108, 0.1664, 0.1649, 0.0754, 0.0907, 0.0973]],
       grad_fn=<MaskedFillBackward0>)


### Step 3: Normalizing to Attention Weights

In [37]:
attn_weights = torch.softmax(masked_attn_score / keys.shape[-1]**0.5, dim=1)
print(attn_weights)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5075, 0.4925, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3399, 0.3303, 0.3298, 0.0000, 0.0000, 0.0000],
        [0.2562, 0.2519, 0.2517, 0.2402, 0.0000, 0.0000],
        [0.2021, 0.2030, 0.2028, 0.1962, 0.1959, 0.0000],
        [0.1758, 0.1704, 0.1702, 0.1598, 0.1615, 0.1623]],
       grad_fn=<SoftmaxBackward0>)


### Step 4: Add Dropouts

In [38]:
torch.manual_seed(123)

dropout = torch.nn.Dropout(0.5)
dropped_out_attn_weights = dropout(attn_weights)
print(dropped_out_attn_weights)

tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.6799, 0.6606, 0.6595, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.5038, 0.5033, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.4060, 0.0000, 0.3924, 0.0000, 0.0000],
        [0.0000, 0.3408, 0.3404, 0.3196, 0.3230, 0.0000]],
       grad_fn=<MulBackward0>)


### Implementing a Causal Attention Class

In [39]:
class CausalAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        
        self.d_out = d_out
        self.d_in = d_in
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        
        self.drop_out = nn.Dropout(dropout)
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length),diagonal=1))
    
    def forward(self, x):
        # x is a 3-dimensional tensor
        b, num_tokens, d_in = x.shape
        
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        
        attn_scores = queries @ keys.transpose(1,2)
        # Mask attention score
        attn_scores.masked_fill(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.drop_out(attn_weights)
        
        context_embedding = attn_weights @ values
        
        return context_embedding

In [40]:
batch = torch.stack((inputs, inputs), dim=0) # Create a batch of input tokens (2 x (6 x 3))
print(batch.shape)

context_length = batch.shape[1]

ca = CausalAttention(d_in, d_out, context_length, 0.0)
context_vecs = ca(batch)

print(context_vecs.shape)

torch.Size([2, 6, 3])
torch.Size([2, 6, 2])


## Multi-head Attention

In [41]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        """
        Class of Multi-head Attention
        :param d_in: dimension of input embedding
        :param d_out: dimension of context embedding
        :param context_length: dimension of attention weights, which is the number of inputs, d_in
        :param dropout: 
        :param num_heads: 
        :param qkv_bias: 
        """
        super().__init__()
        assert (d_out % num_heads == 0), "d_out must be divisible by num_heads"
        
        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads  # dimension per head
        
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        
        self.out_proj = nn.Linear(d_out, d_out)  # the final concatenation of context embedding

        # Define masking
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length),
                       diagonal=1)
        )
        
    def forward(self, x):
        """
        :param x: [n, [N, M]] n batches of input sequence (N x M)
        """

        b, num_tokens, d_in = x.shape  # b: the number of batches
        
        print(f"input shape: {x.shape}")
        
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)

        print(f"keys shape: {keys.shape}")
        
        # Splits query, key, value matrices in
        # to several parts by d_out
        # self.num_heads x self.head_dim = d_out
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(
            b, num_tokens, self.num_heads, self.head_dim
        )
        print(f"keys shape after split: {keys.shape}")
        
        # Transposes from shape (b, num_tokens, num_heads, head_dim) to (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)
        print(f"keys shape after transpose(1,2): {keys.shape}")

        # (b, num_heads, head_dim, num_tokens)
        attn_scores = queries @ keys.transpose(2, 3)
        print(f"attention score shape: {attn_scores.shape}")

        # Causal masking
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        
        # Dropout attn weights
        attn_weights = self.dropout(attn_weights)

        context_vec = (attn_weights @ values).transpose(1, 2)

        # Combine views
        context_vec = context_vec.contiguous().view(
            b, num_tokens, self.d_out
        )

        context_vec = self.out_proj(context_vec)
        print(f"context vec shape: {context_vec.shape}")
        
        return context_vec

In [42]:
torch.manual_seed(123)
batch_size, context_length, d_in = batch.shape
d_out = 2
mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)
context_vecs = mha(batch)
print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

input shape: torch.Size([2, 6, 3])
keys shape: torch.Size([2, 6, 2])
keys shape after split: torch.Size([2, 6, 2, 1])
keys shape after transpose(1,2): torch.Size([2, 2, 6, 1])
attention score shape: torch.Size([2, 2, 6, 6])
context vec shape: torch.Size([2, 6, 2])
tensor([[[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]],

        [[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]]], grad_fn=<ViewBackward0>)
context_vecs.shape: torch.Size([2, 6, 2])


In [44]:
# The shape of this tensor is (b, num_heads, num_tokens, head_dim) = (1, 2, 3, 4).
a = torch.tensor([[[[0.2745, 0.6584, 0.2775, 0.8573],
                    [0.8993, 0.0390, 0.9268, 0.7388],
                    [0.7179, 0.7058, 0.9156, 0.4340]],

                   [[0.0772, 0.3565, 0.1479, 0.5331],
                    [0.4066, 0.2318, 0.4545, 0.9737],
                    [0.4606, 0.5159, 0.4220, 0.5786]]]])

print(a.shape)

print(a.transpose(1,2))

print(a.transpose(1,2).transpose(2, 3))

print(a @ a.transpose(2, 3))

print(a.view(1,2,12))

torch.Size([1, 2, 3, 4])
tensor([[[[0.2745, 0.6584, 0.2775, 0.8573],
          [0.0772, 0.3565, 0.1479, 0.5331]],

         [[0.8993, 0.0390, 0.9268, 0.7388],
          [0.4066, 0.2318, 0.4545, 0.9737]],

         [[0.7179, 0.7058, 0.9156, 0.4340],
          [0.4606, 0.5159, 0.4220, 0.5786]]]])
tensor([[[[0.2745, 0.0772],
          [0.6584, 0.3565],
          [0.2775, 0.1479],
          [0.8573, 0.5331]],

         [[0.8993, 0.4066],
          [0.0390, 0.2318],
          [0.9268, 0.4545],
          [0.7388, 0.9737]],

         [[0.7179, 0.4606],
          [0.7058, 0.5159],
          [0.9156, 0.4220],
          [0.4340, 0.5786]]]])
tensor([[[[1.3208, 1.1631, 1.2879],
          [1.1631, 2.2150, 1.8424],
          [1.2879, 1.8424, 2.0402]],

         [[0.4391, 0.7003, 0.5903],
          [0.7003, 1.3737, 1.0620],
          [0.5903, 1.0620, 0.9912]]]])
tensor([[[0.2745, 0.6584, 0.2775, 0.8573, 0.8993, 0.0390, 0.9268, 0.7388,
          0.7179, 0.7058, 0.9156, 0.4340],
         [0.0772, 0.356