# 1. Assignment Overview

In this assignment, you will build all the components needed to train a standard Transformer language model (LM) from scratch and train some models.

## What you will implement
1. Byte-pair encoding (BPE) tokenizer (§2)
2. Transformer language model (LM) (§3)
3. The cross-entropy loss function and the AdamW optimizer (§4)
4. The training loop, with support for serializing and loading model and optimizer state (§5)

## What you will run
1. Train a BPE tokenizer on the TinyStories dataset.
2. Run your trained tokenizer on the dataset to convert it into a sequence of integer IDs.
3. Train a Transformer LM on the TinyStories dataset.
4. Generate samples and evaluate perplexity using the trained Transformer LM.
5. Train models on OpenWebText and submit your attained perplexities to a leaderboard.

## What you can use
We expect you to build these components from scratch. In particular, you may **not** use any definitions from `torch.nn`, `torch.nn.functional`, or `torch.optim` except for the following:

- `torch.nn.Parameter`
- Container classes in `torch.nn` (e.g., `Module`, `ModuleList`, `Sequential`, etc.)
- The `torch.optim.Optimizer` base class

You may use any other PyTorch definitions. If you would like to use a function or class and are not sure whether it is permitted, feel free to ask on Slack. When in doubt, consider whether using it compromises the “from-scratch” ethos of the assignment.

## Statement on AI tools

Prompting LLMs such as ChatGPT is permitted for low-level programming questions or high-level conceptual questions about language models, but using them directly to solve the problem is **prohibited**.

We strongly encourage you to disable AI autocomplete (e.g., Cursor Tab, GitHub Copilot) in your IDE when completing assignments (non-AI autocomplete, e.g., function-name completion, is fine). AI autocomplete often makes it harder to engage deeply with the content.

## What the code looks like

All the assignment code and this writeup are on GitHub:

> https://github.com/stanford-cs336/assignment1-basics

Please `git clone` the repo. If there are updates, we’ll notify you so you can `git pull` to get the latest.

1. **`cs336_basics/*`**: Your code lives here. There’s no code provided—you can implement everything from scratch.
2. **`adapters.py`**: Defines required functionality your code must expose. For each feature (e.g., scaled dot-product attention), fill in the corresponding hook (e.g., `run_scaled_dot_product_attention`) by invoking your implementation. *Do not put substantive logic here; it’s glue code.*
3. **`test_*.py`**: Tests you must pass (e.g., `test_scaled_dot_product_attention`) that call hooks defined in `adapters.py`. **Do not edit** the tests.

## How to submit

You will submit the following files to Gradescope:

- **`writeup.pdf`**: Answer all written questions. Please typeset your responses.
- **`code.zip`**: Contains all the code you’ve written.

To submit to the leaderboard, open a PR to:

> https://github.com/stanford-cs336/assignment1-basics-leaderboard

See the `README.md` in the leaderboard repository for detailed submission instructions.

## Where to get datasets

This assignment will use two pre-processed datasets: **TinyStories** [Eldan and Li, 2023] and **OpenWebText** [Gokaslan et al., 2019]. Both datasets are single, large plaintext files. If you are doing the assignment with the class, you can find these files at `/data` of any non-head node machine. If you are following along at home, you can download these files with the commands inside the `README.md`.



> **Low-Resource/Downscaling Tip: Init**  
> Throughout the course’s assignment handouts, we will give advice for working through parts of the assignment with fewer or no GPU resources. For example, we will sometimes suggest **downscaling** your dataset or model size, or explain how to run training code on a macOS integrated GPU or CPU. You’ll find these “low-resource tips” in a blue box (like this one). Even if you are an enrolled Stanford student with access to the course machines, these tips may help you iterate faster and save time, so we recommend you read them!

> **Low-Resource/Downscaling Tip: Assignment 1 on Apple Silicon or CPU**  
> With the staff solution code, we can train an LM to generate reasonably fluent text on an Apple M3 Max chip with 36 GB RAM, in under **5 minutes** on Metal GPU (MPS) and about **30 minutes** using the CPU. If these words don’t mean much to you, don’t worry! Just know that if you have a reasonably up-to-date laptop and your implementation is correct and efficient, you will be able to train a small LM that generates simple children’s stories with decent fluency.  
> Later in the assignment, we will explain what changes to make if you are on CPU or MPS.

# 2. Byte-Pair Encoding (BPE) Tokenizer

In the first part of the assignment, we will train and implement a **byte-level** byte-pair encoding (BPE) tokenizer [Sennrich et al., 2016; Wang et al., 2019]. In particular, we will represent arbitrary (Unicode) strings as a sequence of **bytes** and train our BPE tokenizer on this byte sequence. Later, we will use this tokenizer to encode text (a string) into **tokens** (a sequence of integers) for language modeling.

## 2.1 The Unicode Standard

Unicode is a text encoding standard that maps characters to integer **code points**. As of Unicode 16.0 (released in September 2024), the standard defines **154,998** characters across **168** scripts. For example, the character “s” has the code point **115** (typically notated as `U+0073`, where `U+` is a conventional prefix and `0073` is 115 in hexadecimal), and the character “牛” has the code point **29275**.  
In Python, you can use the `ord()` function to convert a single Unicode character into its integer representation. The `chr()` function converts an integer Unicode code point into a string with the corresponding character.

```py
>>> ord('牛')
29275
>>> chr(29275)
'牛'
```

> **Problem (unicode1): Understanding Unicode (1 point)**

**(a)** What Unicode character does `chr(0)` return?  
**Deliverable:** A one-sentence response.

**(b)** How does this character’s string representation (`__repr__()`) differ from its printed representation?  
**Deliverable:** A one-sentence response.

**(c)** What happens when this character occurs in text? It may be helpful to play around with the following in your Python interpreter and see if it matches your expectations:

In [1]:
chr(0)

'\x00'

In [2]:
print(chr(0))

 


In [3]:
"this is a test" + chr(0) + "string"

'this is a test\x00string'

In [8]:
print("this is a test" + chr(0) + "string")

this is a test string


## 2.2 Unicode Encodings

While the Unicode standard defines a mapping from characters to code points (integers), it’s impractical to train tokenizers directly on Unicode code points, since the vocabulary would be prohibitively large (around 150K items) and sparse (since many characters are quite rare). Instead, we’ll use a Unicode **encoding**, which converts a Unicode character into a sequence of **bytes**. The Unicode standard itself defines three encodings: **UTF-8**, **UTF-16**, and **UTF-32**, with **UTF-8** being the dominant encoding for the Internet (more than 98% of all webpages).

To encode a Unicode string into UTF-8, we can use the `encode()` function in Python. To access the underlying byte values for a Python `bytes` object, we can iterate over it (e.g., call `list()`). Finally, we can use the `decode()` function to decode a UTF-8 byte string into a Unicode string.

```py
>>> test_string = "hello! こんにちは!"
>>> utf8_encoded = test_string.encode("utf-8")
>>> print(utf8_encoded)
b'hello! \xe3\x81\x93\xe3\x82\x93\xe3\x81\xab\xe3\x81\xa1\xe3\x81\xaf!'
>>> print(type(utf8_encoded))
<class 'bytes'>

# Get the byte values for the encoded string (integers from 0 to 255).
>>> list(utf8_encoded)
[104, 101, 108, 108, 111, 33, 32, 227, 129, 147, 227, 130, 147, 227, 129, 171, 227, 129, 161, 227, 129, 175, 33]

# One byte does not necessarily correspond to one Unicode character!
>>> print(len(test_string))
13
>>> print(len(utf8_encoded))
23

>>> print(utf8_encoded.decode("utf-8"))
hello! こんにちは!
```


By converting our Unicode code points into a sequence of bytes (e.g., via the UTF-8 encoding), we are essentially taking a sequence of code points (integers that can be as large as 1,114,111, i.e., U+10FFFF, of which roughly 155K are assigned characters) and transforming it into a sequence of byte values (integers in the range 0 to 255). The 256-entry byte vocabulary is much more manageable. When using byte-level tokenization, we do not need to worry about out-of-vocabulary tokens, since any input text can be expressed as a sequence of integers from 0 to 255.
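
As a quick check of this claim, the sketch below round-trips a string through its UTF-8 byte values and confirms that every value lies in the range 0 to 255. The helper names `text_to_byte_ids` and `byte_ids_to_text` are illustrative only and are not part of the assignment's API.

```python
def text_to_byte_ids(text: str) -> list[int]:
    """Encode text as UTF-8 and return the byte values (each in 0..255)."""
    return list(text.encode("utf-8"))


def byte_ids_to_text(byte_ids: list[int]) -> str:
    """Invert text_to_byte_ids by rebuilding the bytes object and decoding it."""
    return bytes(byte_ids).decode("utf-8")


sample = "hello 牛"
ids = text_to_byte_ids(sample)
print(ids)                                # one integer per UTF-8 byte
print(all(0 <= i <= 255 for i in ids))    # every value fits the 256-entry vocabulary
print(byte_ids_to_text(ids) == sample)    # the mapping is lossless
```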

> **Problem (unicode2): Unicode Encodings (3 points)**

**(a)** What are some reasons to prefer training our tokenizer on UTF-8 encoded bytes, rather than UTF-16 or UTF-32? It may be helpful to compare the output of these encodings for various input strings.  
**Deliverable:** A one-to-two sentence response.

**(b)** Consider the following (incorrect) function, which is intended to decode a UTF-8 byte string into a Unicode string. Why is this function incorrect? Provide an example of an input byte string that yields incorrect results.

```python
def decode_utf8_bytes_to_str_wrong(bytestring: bytes):
    return "".join([bytes([b]).decode("utf-8") for b in bytestring])

>>> decode_utf8_bytes_to_str_wrong("hello".encode("utf-8"))
'hello'
```
**Deliverable:** An example input byte string for which `decode_utf8_bytes_to_str_wrong` produces incorrect output, with a one-sentence explanation of why the function is incorrect.

**(c)** Give a two-byte sequence that does not decode to any Unicode character(s).

**Deliverable:** An example, with a one-sentence explanation.

In [8]:
# try encoding a string with utf-8
test_string = "hello 黄奕!"  # try your name
utf8_encoded = test_string.encode("utf-8")
print(utf8_encoded)  # output: b'hello \xe9\xbb\x84\xe5\xa5\x95!'
print(type(utf8_encoded)) 

# see the byte values
list(utf8_encoded)  

# verify the reversibility
print(len(test_string), len(utf8_encoded))  
print(utf8_encoded.decode("utf-8"))  # perfectly restored: 'hello 黄奕!'

b'hello \xe9\xbb\x84\xe5\xa5\x95!'
<class 'bytes'>
9 13
hello 黄奕!


In [10]:
def decode_utf8_bytes_to_str_wrong(bytestring: bytes):
    return "".join([bytes([b]).decode("utf-8") for b in bytestring])

decode_utf8_bytes_to_str_wrong("hello".encode("utf-8"))

'hello'

In [11]:
decode_utf8_bytes_to_str_wrong("café".encode("utf-8"))

UnicodeDecodeError: 'utf-8' codec can't decode byte 0xc3 in position 0: unexpected end of data

### Why the function is incorrect:

The function `decode_utf8_bytes_to_str_wrong` is incorrect because **UTF-8 is a variable-length encoding**. Some characters require multiple bytes to represent them, but the function tries to decode each individual byte as if it were a complete UTF-8 character.

### The problem with "café":

1. **"café"** in UTF-8 is encoded as: `b'caf\xc3\xa9'`
2. The bytes are: `[99, 97, 102, 195, 169]`
3. The first three bytes (`c`, `a`, `f`) are ASCII characters that can be decoded individually
4. However, the **é** character requires **2 bytes**: `0xc3` and `0xa9`
   - `0xc3` is the first byte of a 2-byte UTF-8 sequence, but when decoded alone, it's incomplete
   - `0xa9` is the second byte, but when decoded alone, it's not a valid UTF-8 start byte

### The correct approach:

Instead of decoding each byte individually, you should decode the entire byte string at once:

```python
def decode_utf8_bytes_to_str_correct(bytestring: bytes):
    return bytestring.decode("utf-8")
```

This works because the UTF-8 decoder knows how to properly interpret the multi-byte sequences and reconstruct the complete Unicode characters.

The error shown above (`'utf-8' codec can't decode byte 0xc3 in position 0: unexpected end of data`) occurs because `0xc3` is the start of a 2-byte UTF-8 sequence; when it is decoded as a single byte, the decoder expects a continuation byte that isn't there.

In [6]:
def decode_utf8_bytes_to_str_correct(bytestring: bytes):
    return bytestring.decode("utf-8")

In [7]:
decode_utf8_bytes_to_str_correct("café".encode("utf-8"))

'café'

## 2.3 Subword Tokenization

While byte-level tokenization can alleviate the out-of-vocabulary issues faced by word-level tokenizers, tokenizing text into **bytes** results in extremely long input sequences. This slows down model training: a sentence with 10 words might be ~10 tokens in a word-level LM, but could be 50+ tokens in a character/byte-level model (depending on word lengths). Longer sequences increase compute per step, and modeling long byte sequences creates longer-term dependencies.

**Subword tokenization** sits between word-level and byte-level tokenizers.

- A byte-level tokenizer’s vocabulary has **256** entries (byte values 0–255).
- A subword tokenizer trades a **larger vocabulary** for **shorter sequences** (better compression).
- Example: if the byte sequence `b'the'` occurs frequently, assigning it its own token reduces a 3-byte sequence to **one** token.

**Choosing subword units.**  
Sennrich et al. (2016), following Gage (1994), propose **byte-pair encoding (BPE)**—a compression algorithm that iteratively merges the most frequent adjacent pair into a new token. Repeated merges add subword tokens to maximize compression: frequent words may become single subword units.

Subword tokenizers built via BPE are often called **BPE tokenizers**. In this assignment, we implement a **byte-level BPE** tokenizer: vocabulary items are bytes or merged byte sequences, combining robust OOV handling with manageable sequence lengths. Constructing the BPE vocabulary is called **“training”** the tokenizer.


## 2.4 BPE Tokenizer Training

The BPE tokenizer training procedure consists of three main steps.

### Vocabulary initialization
The tokenizer vocabulary is a one-to-one mapping from bytestring token to integer ID. Since we are training a **byte-level** BPE tokenizer, our initial vocabulary is simply the set of all bytes. Because there are 256 possible byte values, the initial vocabulary size is **256**.
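
A minimal sketch of this initialization, using the `dict[int, bytes]` layout (ID to token bytes) that the training function is asked to return later in this section. The function name `init_vocab` and the choice to place special tokens after the 256 byte entries are assumptions made for illustration; other ID layouts are equally valid.

```python
def init_vocab(special_tokens: list[str]) -> dict[int, bytes]:
    """Build the starting vocabulary: 256 single-byte tokens, then special tokens."""
    # IDs 0..255 map to the corresponding single byte.
    vocab = {i: bytes([i]) for i in range(256)}
    # Assumed layout: special tokens take the next free IDs.
    for token in special_tokens:
        vocab[len(vocab)] = token.encode("utf-8")
    return vocab


vocab = init_vocab(["<|endoftext|>"])
print(len(vocab), vocab[97], vocab[256])
```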

### Pre-tokenization
Once you have a vocabulary, you could, in principle, count how often bytes occur next to each other in your text and begin merging them starting with the most frequent pair of bytes. However, this is quite computationally expensive, since you would have to take a full pass over the corpus for each merge. In addition, directly merging bytes across the corpus may result in tokens that differ only in punctuation (e.g., “dog!” vs. “dog.”), which would get completely different token IDs even though they are highly semantically similar.

To avoid this, we **pre-tokenize** the corpus. You can think of this as a coarse-grained tokenization over the corpus that helps us count how often pairs of characters appear. For example, the word `'text'` might be a pre-token that appears 10 times. In this case, when we count how often the characters `t` and `e` appear next to each other, we will see that the word `text` has `t` and `e` adjacent and we can increment their count by 10 instead of scanning the whole corpus. Since we’re training a byte-level BPE model, each pre-token is represented as a sequence of UTF-8 bytes.

The original BPE implementation of Sennrich et al. (2016) pre-tokenizes by simply splitting on whitespace (i.e., `s.split(" ")`). In contrast, we will use a **regex-based** pre-tokenizer (used by GPT-2; Radford et al., 2019) from the tiktoken project:  
`https://github.com/openai/tiktoken/pull/234/files`


In [5]:
PAT = r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

It may be useful to interactively split some text with this pre-tokenizer to get a better sense of its behavior:

In [6]:
# requires `regex` package
import regex as re
re.findall(PAT, "some text that i'll pre-tokenize")
# -> ['some', ' text', ' that', ' i', "'ll", ' pre', '-', 'tokenize']

['some', ' text', ' that', ' i', "'ll", ' pre', '-', 'tokenize']

When using it in your code, however, you should use `re.finditer` to avoid storing the pre-tokenized words
as you construct your mapping from pre-tokens to their counts.
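
The sketch below shows one way to build the pre-token count table with `finditer`, so that only the running counts are kept in memory. To stay runnable with the standard library alone, it uses a simplified stand-in pattern (`SIMPLE_PAT`, an assumption for illustration); in the assignment you would pass the GPT-2 pattern compiled with the `regex` package instead. The name `count_pretokens` is also hypothetical.

```python
import re
from collections import Counter

# Stand-in pattern: an optional leading space followed by non-whitespace characters.
# The GPT-2 pattern needs the third-party `regex` package for \p{L} and \p{N}.
SIMPLE_PAT = re.compile(r" ?\S+")


def count_pretokens(text: str, pattern: re.Pattern = SIMPLE_PAT) -> Counter:
    """Map each pre-token (a tuple of single-byte bytes objects) to its count."""
    counts: Counter = Counter()
    for match in pattern.finditer(text):  # yields one match at a time
        token_bytes = match.group().encode("utf-8")
        counts[tuple(bytes([b]) for b in token_bytes)] += 1
    return counts


counts = count_pretokens("low low low")
for pre_token, count in counts.items():
    print(pre_token, count)
```

Keying the table by `tuple[bytes, ...]` keeps each pre-token in the form the merge step needs, so no second conversion pass is required.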

### Compute BPE merges

Now that we’ve converted our input text into **pre-tokens** and represented each pre-token as a sequence of **UTF-8 bytes**, we can compute the BPE merges (i.e., train the BPE tokenizer).

At a high level, the **BPE algorithm** iteratively:
1. Counts every adjacent **pair of bytes**.
2. Finds the pair with the **highest frequency** (e.g., `("A","B")`).
3. **Merges** every occurrence of this pair into a new token (e.g., `"AB"`).
4. Adds the new token to the **vocabulary**.

The final vocabulary size equals the initial size (**256** for byte-level) **plus** the number of merge operations performed.  
For efficiency, **do not consider pairs that cross pre-token boundaries**.  
When ties occur in pair frequency, **break ties deterministically by choosing the lexicographically greater pair**. Example:


In [7]:
max([("A", "B"), ("A", "C"), ("B", "ZZ"), ("BA", "A")])

('BA', 'A')

In [8]:
max([("e", "s"), ("s", "t")])

('s', 't')
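
During training, the same rule can be applied in a single call by taking the maximum over `(count, pair)` keys, so that frequency is compared first and the pair itself breaks ties. This is a sketch; `pick_best_pair` and the shape of `pair_counts` are illustrative assumptions.

```python
def pick_best_pair(pair_counts: dict[tuple[bytes, bytes], int]) -> tuple[bytes, bytes]:
    """Return the most frequent pair; ties go to the lexicographically greater pair."""
    return max(pair_counts, key=lambda pair: (pair_counts[pair], pair))


example_counts = {(b"e", b"s"): 9, (b"s", b"t"): 9, (b"l", b"o"): 7}
print(pick_best_pair(example_counts))
```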

### Special tokens

Some strings (e.g., `<|endoftext|>`) encode metadata such as document boundaries. When encoding text, it’s often desirable to mark such strings as **special tokens** that **must never be split** and are **always preserved as a single token** (one integer ID). This ensures, for example, that an EOS token is recognized unambiguously during generation.
These special tokens must be **added to the vocabulary** with **fixed IDs**.

> *Reference:* Algorithm 1 of **Sennrich et al. (2016)** gives an (inefficient) BPE training procedure that essentially follows the steps above. Implementing and testing that function is a useful first exercise to check your understanding.


### Example (`bpe_example`): BPE training example

Here is a stylized example from *Sennrich et al.* (2016). Consider a corpus consisting of the following text:

```

low low low low low
lower lower widest widest widest
newest newest newest newest newest newest

```

and the vocabulary has a special token `<|endoftext|>`.

**Vocabulary**  
We initialize our vocabulary with `<|endoftext|>` and the 256 byte values.

**Pre-tokenization**  
For simplicity (to focus on the merge procedure), pre-tokenization splits on whitespace. Counting pre-tokens yields the frequency table:

`{low: 5, lower: 2, widest: 3, newest: 6}`

It’s convenient to represent this as a `dict[tuple[bytes], int]`, e.g. `{(l,o,w): 5, …}`.  
Note: even a single byte is a `bytes` object in Python. There is no separate `byte` type for a single byte, just as there is no `char` type in Python for a single character.

**Merges**  
First, look at every successive pair of bytes and sum the frequency across words where they appear:

`{lo: 7, ow: 7, we: 8, er: 2, wi: 3, id: 3, de: 3, es: 9, st: 9, ne: 6, ew: 6}`

The pairs `(e, s)` and `(s, t)` are tied with a count of 9; choose the lexicographically greater pair, `(s, t)`.  
After this merge the pre-tokens become:  
`{(l,o,w): 5, (l,o,w,e,r): 2, (w,i,d,e,st): 3, (n,e,w,e,st): 6}`.

Second round: `(e, st)` is most common (count = 9); merge to get:  
`{(l,o,w): 5, (l,o,w,e,r): 2, (w,i,d,est): 3, (n,e,w,est): 6}`.

Continuing, if we take **6 merges**, the sequence is:  
`['s t', 'e st', 'o w', 'l ow', 'w est', 'n e']`.

With these merges, the vocabulary elements include:  
`[<|endoftext|>, […256 BYTE CHARS], st, est, ow, low, west, ne]`.

With this vocabulary and merge list, the word **`newest`** tokenizes as:  
`[ne, west]`.
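
The worked example can be reproduced with a deliberately naive training loop that recounts all pairs on every round. This is a sketch for checking your understanding, not an efficient implementation; the function name `train_bpe_naive` and its signature are assumptions and differ from the interface required by the assignment.

```python
from collections import Counter


def train_bpe_naive(word_counts: dict[str, int], num_merges: int) -> list[tuple[bytes, bytes]]:
    """Run `num_merges` rounds of BPE over a pre-token frequency table."""
    # Represent each pre-token as a tuple of single-byte bytes objects.
    words = {
        tuple(bytes([b]) for b in word.encode("utf-8")): count
        for word, count in word_counts.items()
    }
    merges: list[tuple[bytes, bytes]] = []
    for _ in range(num_merges):
        # Recount every adjacent pair, weighted by pre-token frequency.
        pair_counts: Counter = Counter()
        for word, count in words.items():
            for pair in zip(word, word[1:]):
                pair_counts[pair] += count
        if not pair_counts:
            break
        # Highest count wins; ties go to the lexicographically greater pair.
        best = max(pair_counts, key=lambda p: (pair_counts[p], p))
        merges.append(best)
        # Rewrite every pre-token, replacing occurrences of the best pair.
        new_words: dict[tuple[bytes, ...], int] = {}
        for word, count in words.items():
            out, i = [], 0
            while i < len(word):
                if i + 1 < len(word) and (word[i], word[i + 1]) == best:
                    out.append(best[0] + best[1])
                    i += 2
                else:
                    out.append(word[i])
                    i += 1
            new_words[tuple(out)] = new_words.get(tuple(out), 0) + count
        words = new_words
    return merges


example = {"low": 5, "lower": 2, "widest": 3, "newest": 6}
merges = train_bpe_naive(example, num_merges=6)
print(merges)
```

On the frequency table above, this yields the same six merges listed in the example, in the same order.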

## 2.5 Experimenting with BPE Tokenizer Training

Let’s train a byte-level BPE tokenizer on the TinyStories dataset. Instructions to find/download the dataset are in Section 1. Before you start, skim TinyStories to get a sense of what’s in the data.

**Parallelizing pre-tokenization**  
A major bottleneck is pre-tokenization. You can speed this up by parallelizing using Python’s built-in `multiprocessing`. In parallel implementations, **chunk the corpus so that chunk boundaries occur at the beginning of a special token**. You can copy the starter code here to compute chunk boundaries and distribute work:

<https://github.com/stanford-cs336/assignment1-basics/blob/main/cs336_basics/pretokenization_example.py>

This chunking is always valid because we never merge across document boundaries. For this assignment, you can always split in this way. Don’t worry about the edge case of a huge corpus that lacks `<|endoftext|>`.

**Removing special tokens before pre-tokenization**  
Before running pre-tokenization with a regex pattern (e.g., using `re.finditer`), strip out all special tokens from your corpus (or from each chunk in a parallel implementation). **Split on the special tokens** so no merging can occur across the document boundary. This can be done with `re.split`, using the special tokens joined by `"|"` as the delimiter pattern (**make sure to apply `re.escape` to each token first, since `|` occurs in the special tokens and is a regex metacharacter**). The test `test_train_bpe_special_tokens` checks this.
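
A minimal sketch of this splitting step with the standard-library `re` module. The helper name `split_on_special_tokens` is hypothetical, and sorting the tokens longest-first is an extra precaution (assumed here) so that a special token that is a prefix of another does not match early.

```python
import re


def split_on_special_tokens(text: str, special_tokens: list[str]) -> list[str]:
    """Split text at every special token and drop the tokens themselves."""
    if not special_tokens:
        return [text]
    # Escape each token so characters such as '|' are matched literally.
    ordered = sorted(special_tokens, key=len, reverse=True)
    pattern = "|".join(re.escape(token) for token in ordered)
    # Empty segments (e.g., from back-to-back special tokens) are discarded.
    return [segment for segment in re.split(pattern, text) if segment]


docs = split_on_special_tokens("Once upon<|endoftext|>The end", ["<|endoftext|>"])
print(docs)
```

Each returned segment can then be pre-tokenized independently, which guarantees that no pair is counted across a document boundary.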

**Optimizing the merging step**  
The naïve BPE training loop is slow because it recomputes frequencies over all byte pairs each round. Speed it up by **indexing the counts of all pairs** and **incrementally updating** only the pairs that overlap the most recently merged pair. This caching yields significant speedups, although the **merging phase itself is not parallelizable in Python**.
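
One possible shape for this optimization is sketched below: count all pairs once, keep an index from each pair to the pre-tokens that contain it, and on each merge update only those pre-tokens. All names (`build_index`, `apply_merge`, `merge_pair_in_word`) are illustrative assumptions, and the sketch leaves zero-count entries in the counter rather than pruning them.

```python
from collections import Counter, defaultdict

Word = tuple[bytes, ...]
Pair = tuple[bytes, bytes]


def merge_pair_in_word(word: Word, pair: Pair) -> Word:
    """Replace each non-overlapping occurrence of `pair` in `word`, left to right."""
    out: list[bytes] = []
    i = 0
    while i < len(word):
        if i + 1 < len(word) and (word[i], word[i + 1]) == pair:
            out.append(pair[0] + pair[1])
            i += 2
        else:
            out.append(word[i])
            i += 1
    return tuple(out)


def build_index(words: dict[Word, int]):
    """Count every adjacent pair once and record which pre-tokens contain it."""
    pair_counts: Counter = Counter()
    pair_to_words: dict[Pair, set[Word]] = defaultdict(set)
    for word, count in words.items():
        for pair in zip(word, word[1:]):
            pair_counts[pair] += count
            pair_to_words[pair].add(word)
    return pair_counts, pair_to_words


def apply_merge(words, pair_counts, pair_to_words, pair: Pair) -> None:
    """Merge `pair` in place, touching only the pre-tokens that contain it."""
    for word in list(pair_to_words.get(pair, ())):
        count = words.pop(word)
        for old in zip(word, word[1:]):  # retract the old pre-token's pairs
            pair_counts[old] -= count
            pair_to_words[old].discard(word)
        new_word = merge_pair_in_word(word, pair)
        words[new_word] = words.get(new_word, 0) + count
        for new in zip(new_word, new_word[1:]):  # add the merged pre-token's pairs
            pair_counts[new] += count
            pair_to_words[new].add(new_word)
```

After each call to `apply_merge`, the positive entries of `pair_counts` should equal what a full recount over `words` would produce, which is a convenient invariant to assert while debugging.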


> **Low-Resource/Downscaling Tip: Profiling**  
> Use profiling tools like **cProfile** or **scalene** to identify bottlenecks in your implementation and focus your optimization there.
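
A minimal sketch of profiling a single call with the standard-library `cProfile` and `pstats` modules. The helper name `profile_call` is hypothetical, and the toy workload stands in for your own training function.

```python
import cProfile
import io
import pstats


def profile_call(fn, *args, **kwargs):
    """Run fn under cProfile and return (result, text report of the top entries)."""
    profiler = cProfile.Profile()
    profiler.enable()
    try:
        result = fn(*args, **kwargs)
    finally:
        profiler.disable()
    buffer = io.StringIO()
    stats = pstats.Stats(profiler, stream=buffer)
    # Sort by cumulative time so the most expensive call chains come first.
    stats.sort_stats("cumulative").print_stats(10)
    return result, buffer.getvalue()


result, report = profile_call(sum, range(1000))
print(result)
print(report)
```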

> **Low-Resource/Downscaling Tip: “Downscaling”**  
> Instead of immediately training on the full TinyStories dataset, start with a **debug dataset** (a small subset). For example, train on the **TinyStories validation set** (≈22K documents) instead of the full training set (2.12M documents). This general strategy speeds up development (smaller datasets, smaller models, etc.). Choose the debug size carefully: **large enough** to surface the same bottlenecks as the full setup (so optimizations generalize), but **not so large** that runs take forever.

### Problem (`train_bpe`): BPE Tokenizer Training *(15 points)*

**Deliverable**  
Write a function that, given a path to an input text file, trains a **byte-level BPE tokenizer**.  
Your function should accept at least:

- `input_path: str` — Path to a text file with BPE tokenizer training data.  
- `vocab_size: int` — Maximum final vocabulary size (includes initial byte vocabulary, merges, and any special tokens).  
- `special_tokens: list[str]` — Strings to add to the vocabulary; these do **not** otherwise affect BPE training.

Return the resulting vocabulary and merges:

- `vocab: dict[int, bytes]` — Tokenizer vocabulary mapping token ID (`int`) → token bytes (`bytes`).  
- `merges: list[tuple[bytes, bytes]]` — BPE merges produced during training.  
  Each item is a tuple `(<token1>, <token2>)`, meaning `<token1>` was merged with `<token2>`.  
  **Order merges by creation time.**

**Testing**  
To run our tests, first implement the test adapter at `adapters.run_train_bpe`.  
Then run:
```bash
uv run pytest tests/test_train_bpe.py
```

Your implementation should pass all tests.

**Optional speedups**
You may implement hot paths in a systems language:

* **C++** (consider **cppyy**) or **Rust** (using **PyO3**).
  Be mindful of copy vs. zero-copy from Python memory. Provide build instructions or ensure it builds with only `pyproject.toml`.

**Regex note**
The **GPT-2 regex** is not well supported (and often too slow) in many engines.
We verified **Oniguruma** is reasonably fast and supports negative lookahead, but Python’s **`regex`** package is—if anything—even faster.



In [14]:
!uv run pytest tests/test_train_bpe.py

platform darwin -- Python 3.13.5, pytest-8.4.1, pluggy-1.6.0
rootdir: /Users/yihuang/projects/data_science/cd336-a1/cs336-hw1
configfile: pyproject.toml
plugins: jaxtyping-0.3.2
collected 3 items

tests/test_train_bpe.py::test_train_bpe_speed PASSED
tests/test_train_bpe.py::test_train_bpe PASSED
tests/test_train_bpe.py::test_train_bpe_special_tokens PASSED




### Problem (`train_bpe_tinystories`): BPE Training on TinyStories *(2 points)*

**(a)** Train a **byte-level BPE tokenizer** on the TinyStories dataset with **vocab size = 10,000**.  
Add the **`<|endoftext|>`** special token to the vocabulary. Serialize the resulting **vocab** and **merges** to disk for inspection.  
Report: **training time**, **peak memory**, **longest token** in the vocab, and whether it **makes sense**.

- **Resource requirements:** ≤ **30 minutes** (no GPUs), ≤ **30 GB RAM**.  
- **Hint:** You should be able to bring BPE training **under 2 minutes** by using `multiprocessing` during pre-tokenization and relying on the following two facts:
  1. `<|endoftext|>` **delimits documents** in the data files.  
  2. `<|endoftext|>` is handled as a **special case** before BPE merges are applied.
- **Deliverable:** a **1–2 sentence** response.

**(b)** **Profile** your code. Which part of tokenizer training takes the **most time**?  
- **Deliverable:** a **1–2 sentence** response.
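
One way to collect the timing and memory numbers requested above is sketched here with the standard-library `time` and `tracemalloc` modules. The helper name `measure` is hypothetical. Note the assumptions: `tracemalloc` only tracks allocations made by Python in the current process (memory used by `multiprocessing` workers is not included), and enabling it slows execution, so treat the figures as approximate.

```python
import time
import tracemalloc


def measure(fn, *args, **kwargs):
    """Run fn and return (result, elapsed seconds, peak traced bytes)."""
    tracemalloc.start()
    start = time.perf_counter()
    try:
        result = fn(*args, **kwargs)
    finally:
        elapsed = time.perf_counter() - start
        # get_traced_memory returns (current, peak) in bytes.
        _, peak_bytes = tracemalloc.get_traced_memory()
        tracemalloc.stop()
    return result, elapsed, peak_bytes


result, elapsed, peak_bytes = measure(lambda: [0] * 100_000)
print(f"elapsed: {elapsed:.4f} s, peak: {peak_bytes / 1e6:.2f} MB")
```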

---

Next, train a byte-level BPE tokenizer on **OpenWebText**. As before, skim the dataset first to understand its contents.

---

### Problem (`train_bpe_expts_owt`): BPE Training on OpenWebText *(2 points)*

**(a)** Train a byte-level BPE tokenizer on **OpenWebText** with **vocab size = 32,000**.  
Serialize the resulting **vocab** and **merges**. What is the **longest token**? Does it **make sense**?

- **Resource requirements:** ≤ **12 hours** (no GPUs), ≤ **100 GB RAM**.  
- **Deliverable:** a **1–2 sentence** response.

**(b)** **Compare and contrast** the tokenizer trained on **TinyStories** vs **OpenWebText**.  
- **Deliverable:** a **1–2 sentence** response.


In [37]:
!cd cs336_basics && python train_bpe.py && cd ..

🚀 Starting BPE training
📋 Config: vocab_size=10000, special_tokens=['<|endoftext|>']
📁 Data path: ../../data/TinyStoriesV2-GPT4-train.txt
📖 Loading training data...
🔧 Running pre-tokenization...
✅ Pre-tokenization complete: 17,380 unique tokens
📊 Counting character-pair frequencies...
✅ Counting complete: found 1,038 character pairs
🔄 Starting BPE training, target vocabulary size: 10000
   Progress: 356/10000 tokens, 99 merges
   Progress: 456/10000 tokens, 199 merges
   Progress: 556/10000 tokens, 299 merges
   Progress: 656/10000 tokens, 399 merges
   Progress: 756/10000 tokens, 499 merges
   Progress: 856/10000 tokens, 599 merges
   Progress: 956/10000 tokens, 699 merges
   Progress: 1056/10000 tokens, 799 merges
   Progress: 1156/10000 tokens, 899 merges
   Progress: 1256/10000 tokens, 999 merges
   Progress: 1356/10000 tokens, 1099 merges
   Progress: 1456/10000 tokens, 1199 merges
   Progress: 1556/10000 tokens, 1299 merges
   Progress: 1656/10000 tokens, 1399 merges
   Progress: 1756/10000 tokens, 1499 merges
   Progress: 1856/10000 tokens, 1599 merges
   Progress: 1956/10000 tokens, 1699 merges
   Progress: 2056/10000 tokens, 1799 merges
   Progress: 2156/10000 tokens, 1899 merges
   Progress: 2256/10000 tokens, 1999 merges
   Progress: 2356/10000 tokens, 2099 mer

#### BPE Training Results (TinyStories)

**Training Time:** 78.65 seconds (1.31 minutes) - ✅ **Meets requirement** (≤ 30 minutes, and under the 2-minute target from the hint)

**Peak Memory:** 0.30 GB - ✅ **Well within limit** (≤ 30 GB)

**Longest Token:** `' enthusiastically'` (17 bytes) - ✅ **Makes sense** (an ordinary English word with a leading space)

**Vocab Size:** 10,000 tokens (including 256 base bytes + 1 special token + 9,743 BPE merges)

**Special Token:** `<|endoftext|>` successfully added to vocabulary

**Files Generated:**
- `vocab.json`: Vocabulary mapping (token_id → token_string)
- `merges.txt`: BPE merge operations (9743 merges)

The tokenizer reached the required vocabulary size on TinyStories within the time and memory limits. The longest token is an ordinary English word, which suggests that the merges capture meaningful subword patterns.


## 2.6 BPE Tokenizer: Encoding and Decoding

In the previous part, we trained a BPE tokenizer on input text to obtain a tokenizer **vocabulary** and a list of **BPE merges**. Now we’ll implement a tokenizer that **loads** a provided vocabulary and merge list and uses them to **encode** and **decode** text to/from token IDs.

### 2.6.1 Encoding text

The encoding process mirrors how we trained the BPE vocabulary. Steps:

**Step 1: Pre-tokenize.**  
Pre-tokenize the sequence and represent each pre-token as a sequence of **UTF-8 bytes**, just as in BPE training. We merge bytes **within each pre-token** into vocabulary elements, handling each pre-token independently (**no merges across pre-token boundaries**).

**Step 2: Apply the merges.**  
Take the sequence of vocabulary-element merges created during BPE training and **apply them to the pre-tokens in the same order of creation**.


### Example (`bpe_encoding`): BPE encoding example

Suppose our input string is **`'the cat ate'`**.  
Our vocabulary is:
`{0: b' ', 1: b'a', 2: b'c', 3: b'e', 4: b'h', 5: b't', 6: b'th', 7: b' c', 8: b' a', 9: b'the', 10: b' at'}`

Our learned merges are:
`[(b't', b'h'), (b' ', b'c'), (b' ', b'a'), (b'th', b'e'), (b' a', b't')]`.

Pre-tokenization splits the string into: `['the', ' cat', ' ate']`.  
We then apply merges **in order of creation** within each pre-token.

- **Pre-token `'the'`** initially as `[b't', b'h', b'e']`  
  1) Apply `(b't', b'h') → [b'th', b'e']`  
  2) Apply `(b'th', b'e') → [b'the']`  
  → IDs: **[9]**

- **Pre-token `' cat'`** becomes `[b' ', b'c', b'a', b't']`  
  1) Apply `(b' ', b'c') → [b' c', b'a', b't']`  
  → IDs: **[7, 1, 5]**

- **Pre-token `' ate'`** becomes `[b' ', b'a', b't', b'e']`  
  1) `(b' ', b'a') → [b' a', b't', b'e']`  
  2) `(b' a', b't') → [b' at', b'e']`  
  → IDs: **[10, 3]**

**Final encoded sequence:** **`[9, 7, 1, 5, 10, 3]`**.
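
The example can be checked with a short sketch that applies the merges in creation order to each pre-token and then looks up the IDs. The pre-tokens are written out by hand here rather than produced by the regex, and the function name `apply_merges` is an assumption. Looping over the full merge list for every pre-token is simple but slow; it is meant to mirror the steps above, not to be an efficient encoder.

```python
def apply_merges(pre_token: bytes, merges: list[tuple[bytes, bytes]]) -> list[bytes]:
    """Apply each merge, in creation order, to one pre-token."""
    parts = [bytes([b]) for b in pre_token]
    for left, right in merges:
        merged, i = [], 0
        while i < len(parts):
            if i + 1 < len(parts) and parts[i] == left and parts[i + 1] == right:
                merged.append(left + right)
                i += 2
            else:
                merged.append(parts[i])
                i += 1
        parts = merged
    return parts


vocab = {0: b" ", 1: b"a", 2: b"c", 3: b"e", 4: b"h", 5: b"t",
         6: b"th", 7: b" c", 8: b" a", 9: b"the", 10: b" at"}
merges = [(b"t", b"h"), (b" ", b"c"), (b" ", b"a"), (b"th", b"e"), (b" a", b"t")]
token_to_id = {token: token_id for token_id, token in vocab.items()}

pre_tokens = ["the", " cat", " ate"]  # output of pre-tokenization, written by hand
ids = [
    token_to_id[part]
    for pre_token in pre_tokens
    for part in apply_merges(pre_token.encode("utf-8"), merges)
]
print(ids)  # [9, 7, 1, 5, 10, 3]
```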

---

**Special tokens.** Your tokenizer should correctly handle user-defined special tokens when encoding (assuming they were provided at construction time).

**Memory considerations.** When tokenizing very large files/streams, process the text in manageable **chunks** to keep memory complexity constant. Ensure that **tokens do not cross chunk boundaries**; otherwise, the result may differ from naïvely tokenizing the entire sequence in memory.
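
A sketch of the lazy pattern this implies: consume the input one chunk at a time and yield IDs as they are produced, so that only the current chunk is held in memory. The standalone function and its `encode` callback are assumptions for illustration (the toy encoder below just returns byte values), and the sketch assumes each chunk, such as a line from a file handle, already ends at a position where no token would span the boundary.

```python
from typing import Callable, Iterable, Iterator


def encode_iterable(chunks: Iterable[str], encode: Callable[[str], list[int]]) -> Iterator[int]:
    """Lazily yield token IDs, encoding one chunk at a time."""
    for chunk in chunks:
        # Only the IDs for the current chunk are materialized.
        yield from encode(chunk)


def toy_encode(text: str) -> list[int]:
    """Stand-in encoder: one ID per UTF-8 byte."""
    return list(text.encode("utf-8"))


stream = encode_iterable(["ab\n", "c\n"], toy_encode)
print(list(stream))
```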


### 2.6.2 Decoding text

To decode a sequence of integer token IDs back to raw text, look up each ID’s entry in the vocabulary (a **byte sequence**), concatenate them, then decode the bytes to a **Unicode string**.

Note that input IDs are not guaranteed to map to valid Unicode strings (a user could provide any sequence of IDs). If the bytes do **not** produce valid Unicode, replace malformed bytes with the official Unicode replacement character **U+FFFD**. The `errors` argument of `bytes.decode` controls how decoding errors are handled; using `errors='replace'` automatically substitutes malformed data with the replacement marker.
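
A minimal sketch of this decoding step. The function name `decode_ids` and the three-entry toy vocabulary are assumptions for illustration; the toy vocabulary splits the three UTF-8 bytes of “牛” across two tokens to show that bytes must be concatenated before decoding.

```python
def decode_ids(ids: list[int], vocab: dict[int, bytes]) -> str:
    """Concatenate the bytes for each ID, then decode, replacing malformed bytes."""
    data = b"".join(vocab[token_id] for token_id in ids)
    # errors="replace" substitutes U+FFFD for bytes that are not valid UTF-8.
    return data.decode("utf-8", errors="replace")


toy_vocab = {0: b"\xe7", 1: b"\x89\x9b", 2: b"!"}
print(decode_ids([0, 1, 2], toy_vocab))        # the bytes join into a valid character
print(repr(decode_ids([0, 2], toy_vocab)))     # the lone lead byte becomes U+FFFD
```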


### Problem (`tokenizer`): Implementing the tokenizer *(15 points)*

**Deliverable**  
Implement a **Tokenizer** class that, given a vocabulary and a list of merges,  
- **encodes** text into integer IDs, and  
- **decodes** integer IDs back into text.  

It should also support **user-provided special tokens** (append them to the vocab if missing).

We recommend the following interface:

```python
def __init__(self, vocab, merges, special_tokens=None):
    """Construct from a vocabulary, merges, and optional special tokens."""
```

Parameters:

* `vocab: dict[int, bytes]`
* `merges: list[tuple[bytes, bytes]]`
* `special_tokens: list[str] | None = None`

```python
@classmethod
def from_files(cls, vocab_filepath, merges_filepath, special_tokens=None):
    """Build from serialized vocab/merges (same format as your trainer outputs)."""
```

Additional parameters:

* `vocab_filepath: str`
* `merges_filepath: str`
* `special_tokens: list[str] | None = None`

```python
def encode(self, text: str) -> list[int]:
    """Encode a string into token IDs."""
```

```python
from typing import Iterable, Iterator

def encode_iterable(self, iterable: Iterable[str]) -> Iterator[int]:
    """Given an iterable of strings (e.g., a file handle), lazily yield token IDs.
    Enables memory-efficient tokenization of large files."""
```

```python
def decode(self, ids: list[int]) -> str:
    """Decode a sequence of token IDs into text."""
```

**Testing**
First implement the test adapter at `adapters.get_tokenizer`.
Then run:

```bash
uv run pytest tests/test_tokenizer.py
```

Your implementation should pass all tests.


See the implementation in `tokenizer.py`.

In [39]:
!uv run pytest tests/test_tokenizer.py -v

platform darwin -- Python 3.13.5, pytest-8.4.1, pluggy-1.6.0 -- /Users/yihuang/projects/data_science/cd336-a1/cs336-hw1/.venv/bin/python
cachedir: .pytest_cache
rootdir: /Users/yihuang/projects/data_science/cd336-a1/cs336-hw1
configfile: pyproject.toml
plugins: jaxtyping-0.3.2
collected 25 items

tests/test_tokenizer.py::test_roundtrip_empty PASSED
tests/test_tokenizer.py::test_empty_matches_tiktoken PASSED
tests/test_tokenizer.py::test_roundtrip_single_character PASSED
tests/test_tokenizer.py::test_single_character_matches_tiktoken PASSED
tests/test_tokenizer.py::test_roundtrip_single_unicode_character PASSED
tests/test_tokenizer.py::test_single_unicode_character_matches_tiktoken PASSED
tests/test_tokenizer.py::test_roundtrip_ascii_string PASSED
tests/test_tokenizer.py::test_ascii_string_matches_tiktoken PASSED
tests/test_tokenizer.py::test_roundtrip