# Tokenization and Data Preparation

This notebook is an interactive guide to how raw text is tokenized and turned into training sequences for GPT.


In [None]:
# Import necessary libraries
import os
import sys
from collections import Counter

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import DataLoader

# Make the project importable: the root is taken to be two directory levels
# above the current working directory.
# NOTE(review): this assumes the notebook is run from a folder two levels below
# the repository root - confirm against the actual notebook location.
notebook_dir = os.path.abspath('')
project_root = os.path.dirname(os.path.dirname(notebook_dir))
if project_root not in sys.path:
    # Prepend so the project's packages are searched first
    sys.path.insert(0, project_root)

# Import our tokenizer and dataset classes
from src.data.tokenizer import get_tokenizer
from src.data.dataset import GPTDataset
import tiktoken

## Understanding Tokenization

Tokenization is the process of converting raw text into numerical tokens that the model can process. We use the GPT-2 tokenizer (via tiktoken), which is based on Byte Pair Encoding (BPE).

In [None]:
# Build the GPT-2 tokenizer through the project helper and report its vocabulary size
tokenizer = get_tokenizer("gpt2")
vocab_size = tokenizer.n_vocab
print(f"Vocabulary size: {vocab_size}")

In [None]:
# Round-trip demo: text -> token ids -> text
text = "Hello, world! This is GPT from scratch."
tokens = tokenizer.encode(text)
token_count = len(tokens)
print(f"Text: {text}")
print(f"Tokens: {tokens}")
print(f"Number of tokens: {token_count}")

# Turn the ids back into a string and compare it with the original
decoded = tokenizer.decode(tokens)
round_trip_ok = decoded == text
print(f"Decoded: {decoded}")
print(f"Round-trip successful: {round_trip_ok}")

In [None]:
# Visualize tokenization: show the ids and the text piece behind each id
example_texts = [
    "Hello",
    "Hello world",
    "The quick brown fox",
    "GPT-2 uses Byte Pair Encoding",
    "tokenization"
]

# The loop uses its own variable names so the notebook-level `text` variable
# is not overwritten as a side effect of this demo.
for example in example_texts:
    example_ids = tokenizer.encode(example)
    # Decode one id at a time so each token's own text piece is visible
    pieces = [tokenizer.decode([token_id]) for token_id in example_ids]
    # \u2192 is the right-arrow character; writing it as an escape keeps the
    # source ASCII-only, so the arrow cannot be corrupted by a mis-decoded file.
    print(f"'{example}' \u2192 {example_ids}")
    print(f"  Tokens: {pieces}")
    print()

## Creating a Dataset

The `GPTDataset` class creates training sequences using a sliding window approach. Each window yields an input-target pair in which the target is the same span of tokens shifted forward by one position, so every input token is paired with the token that follows it.

In [None]:
# Locate the sample corpus inside the project's data directory
sample_text_path = os.path.join(project_root, "data", "sample_text.txt")
if not os.path.exists(sample_text_path):
    # No corpus file available: use a short built-in passage instead
    text = "Once upon a time, there was a little girl named Emma. She loved to play in the garden."
    print("Using fallback text")
else:
    with open(sample_text_path, "r", encoding="utf-8") as corpus_file:
        text = corpus_file.read()
    preview = text[:200]
    print(f"Loaded text: {len(text)} characters")
    print(f"First 200 characters: {preview}...")

In [None]:
# Sliding-window settings
context_length = 16  # maximum sequence length
stride = 8  # window step; half of context_length, i.e. 50% overlap

# Build the dataset of training sequences from the loaded text
dataset = GPTDataset(text=text, tokenizer=tokenizer, maximum_length=context_length, stride=stride)

# Report the resulting size together with the settings that produced it
summary_lines = (
    f"Dataset size: {len(dataset)} sequences",
    f"Context length: {context_length}",
    f"Stride: {stride}",
)
print("\n".join(summary_lines))

In [None]:
# Preview the first few (input, target) pairs as both text and raw ids
print("Example sequences from the dataset:\n")
preview_count = min(3, len(dataset))  # never index past the end of a tiny dataset
for i in range(preview_count):
    input_ids, target_ids = dataset[i]

    # Convert once to plain lists; reused for decoding and for display
    input_list = input_ids.tolist()
    target_list = target_ids.tolist()
    input_text = tokenizer.decode(input_list)
    target_text = tokenizer.decode(target_list)

    print(f"Sequence {i}:")
    print(f"  Input:  {input_text}")
    print(f"  Target: {target_text}")
    print(f"  Input tokens:  {input_list}")
    print(f"  Target tokens: {target_list}")
    print()

In [None]:
# Check the next-token relationship on the first sequence of the dataset
input_ids, target_ids = dataset[0]
print("Verifying target shift:")
print(f"Input:  {input_ids.tolist()}")
print(f"Target: {target_ids.tolist()}")
print("\nTarget should be input shifted by 1:")
# Input without its first token is compared element-wise with target without its last token
shift_matches = (input_ids[1:] == target_ids[:-1]).all().item()
print(f"Input[1:] == Target[:-1]: {shift_matches}")
print("Target[-1] is the next token after input[-1]")

## Visualizing Dataset Statistics

In [None]:
# Token statistics for the whole text; the end-of-text marker is allowed as a special token
all_tokens = tokenizer.encode(text, allowed_special={"<|endoftext|>"})
total_tokens = len(all_tokens)
unique_tokens = len(set(all_tokens))
# Share of the tokenizer's vocabulary that occurs in this text, as a percentage
vocab_usage = unique_tokens / tokenizer.n_vocab * 100

print("Text statistics:")
print(f"  Total tokens: {total_tokens:,}")
print(f"  Unique tokens: {unique_tokens:,}")
print(f"  Vocabulary usage: {vocab_usage:.2f}%")
print(f"  Average tokens per character: {total_tokens / len(text):.2f}")

In [None]:
# Plot token frequency distribution
# Counter tallies every token id; most_common(20) returns the 20 highest
# (token_id, count) pairs ordered from most to least frequent.
token_counts = Counter(all_tokens)
top_tokens = token_counts.most_common(20)
token_ids, counts = zip(*top_tokens)

# repr() keeps leading spaces and newlines inside a token visible on the axis;
# the raw decoded strings would render them as blank or broken tick labels.
token_labels = [repr(tokenizer.decode([tid])) for tid in token_ids]
positions = list(range(len(token_labels)))

# Explicit figure/axes objects instead of the implicit pyplot state machine
fig, ax = plt.subplots(figsize=(12, 6))
ax.bar(positions, counts)
ax.set_xticks(positions)
ax.set_xticklabels(token_labels, rotation=45, ha='right')
ax.set_xlabel('Token')
ax.set_ylabel('Frequency')
ax.set_title('Top 20 Most Frequent Tokens')
fig.tight_layout()
plt.show()

## Testing with PyTorch DataLoader

Let's see how the dataset works with PyTorch's DataLoader for batching.

In [None]:
# Create a DataLoader
# (DataLoader is imported in the notebook's top import cell.)
batch_size = 4
# A seeded generator fixes the shuffle order, so re-running the notebook from a
# fresh kernel shows the same batches.
shuffle_generator = torch.Generator().manual_seed(42)
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    generator=shuffle_generator,
)

# Get a batch
for batch_idx, (input_batch, target_batch) in enumerate(dataloader):
    print(f"Batch {batch_idx}:")
    print(f"  Input shape: {input_batch.shape}")  # [batch_size, sequence_length]
    print(f"  Target shape: {target_batch.shape}")  # [batch_size, sequence_length]
    print(f"  Data types: input={input_batch.dtype}, target={target_batch.dtype}")

    # Show first sequence in batch, decoded back to text
    first_input = input_batch[0].tolist()
    first_target = target_batch[0].tolist()
    print(f"  First sequence input: {tokenizer.decode(first_input)}")
    print(f"  First sequence target: {tokenizer.decode(first_target)}")
    print()

    if batch_idx >= 1:  # Just show first 2 batches
        break