## tokenizing with byte-pair encoding

### *Neural Machine Translation of Rare Words with Subword Units, Sennrich et al. (2015)*

Byte-pair encoding (BPE) is a compression algorithm first described by Philip Gage in 1994. The original version compressed data by repeatedly replacing the most frequent pair of adjacent bytes with a single byte that did not otherwise appear in the data. Byte-pair encoding was then [adapted by Rico Sennrich, Barry Haddow and Alexandra Birch](https://arxiv.org/pdf/1508.07909) for neural machine translation, specifically to solve the problem of translating rare words. word-level models, relying on complete words, cannot generate or translate words they have never seen before, a problem that grows when dealing with various dialects, alphabets, etc.

Sennrich et al. showed that rare words can be translated by encoding them as sequences of subword units, and built a vocabulary (a set of tokens, each of variable length) that is fixed in size but can still handle words never seen in training. the argument was that the translation of some words is 'transparent' to any competent translator, even if the word is novel, based on an analysis of the known subwords it contains (morphemes or phonemes).

> a very simplified view of the algorithm:
>
> - Start with a vocabulary that is just the raw symbols (characters)
> - Repeatedly merge the most-common adjacent symbol pairs into new, longer tokens.
> - Stop when you hit a fixed vocab size (e.g. 50k)
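
To make those three bullets concrete, here is a minimal sketch in the spirit of Sennrich et al.: merges are learned from a table of word frequencies, and each word gets an end-of-word marker `</w>` so merges never cross word boundaries. The helper names are my own and the toy word counts mimic the small example in the paper; this is an illustration, not the paper's reference code.

```python
from collections import Counter

def count_pairs(word_freqs):
    """Count adjacent symbol pairs, weighting each word by its corpus frequency."""
    pairs = Counter()
    for symbols, freq in word_freqs.items():
        for pair in zip(symbols, symbols[1:]):
            pairs[pair] += freq
    return pairs

def merge_pair(word_freqs, pair):
    """Fuse every occurrence of `pair` inside every word into a single symbol."""
    merged = Counter()
    for symbols, freq in word_freqs.items():
        out, i = [], 0
        while i < len(symbols):
            if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == pair:
                out.append(symbols[i] + symbols[i + 1])
                i += 2
            else:
                out.append(symbols[i])
                i += 1
        merged[tuple(out)] += freq
    return merged

def learn_bpe(word_counts, num_merges):
    # start from raw characters, plus an end-of-word marker on every word
    word_freqs = Counter({tuple(w) + ("</w>",): c for w, c in word_counts.items()})
    merges = []
    for _ in range(num_merges):
        pairs = count_pairs(word_freqs)
        if not pairs:
            break
        best = max(pairs, key=pairs.get)  # ties go to the pair counted first
        merges.append(best)
        word_freqs = merge_pair(word_freqs, best)
    return merges, word_freqs

toy_counts = {"low": 5, "lower": 2, "newest": 6, "widest": 3}
merges, segmented = learn_bpe(toy_counts, num_merges=3)
print(merges)  # [('e', 's'), ('es', 't'), ('est', '</w>')]
```

With these toy counts the suffix `est` plus the end-of-word marker is learned first, which is exactly the kind of reusable subword unit the paper is after.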

### *Language Models are Unsupervised Multitask Learners, Radford et al. (2019)*

in the context of language models, this paper applied BPE at the byte level, recognizing and leveraging the many desirable properties the algorithm has for language modeling. the alternatives were classic word-level tokenizers that, as Sennrich et al. suggested, choke on funky slang, creative spelling, emojis, and anything else they have never seen.

> nice compression (one token per known word) but brittle. any typoe or neologism explodes into an `<unk>` or a pile of fallback characters.

character-level models fix this coverage problem, but blow up sequence length and more easily lose track of higher-level patterns.

> antidisestablishmentarianism -> 28 tokens. long sequences, slower training, limits how much context we can put into the model's window.

BPE is described as a middle ground between the two: shorter than characters, nimbler than words. It produces fewer tokens than a pure character model (shorter sequences) while staying more adaptable than a pure word model.
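
Both failure modes are easy to reproduce with two toy tokenizers. The vocabulary and the sentence below are invented for illustration; a real word-level tokenizer would have tens of thousands of entries, but the typo would still fall outside it.

```python
def word_tokenize(text, vocab):
    """Word-level: every whitespace-separated word must already be in the vocabulary."""
    return [w if w in vocab else "<unk>" for w in text.split()]

def char_tokenize(text):
    """Character-level: every character is its own token, so nothing is ever unknown."""
    return list(text)

vocab = {"the", "cat", "sat", "on", "mat"}  # hypothetical toy vocabulary
text = "the cat sat on the matt"            # note the typo in "matt"

print(word_tokenize(text, vocab))  # ['the', 'cat', 'sat', 'on', 'the', '<unk>']
print(len(char_tokenize(text)))    # 23
```

The word-level model loses the last word entirely, while the character-level model keeps it but needs 23 tokens for 6 words.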







### implementing byte-pair encoding

Python strings are sequences of Unicode code points. [Unicode](https://en.wikipedia.org/wiki/Unicode) is a character encoding standard defining more than 150,000 characters across 168 scripts.

> From the Python documentation: *Textual data in Python is handled with str objects, or strings. Strings are immutable sequences of Unicode code points.*

the vast majority of text on the internet is encoded using Unicode. The built-in `ord()` function returns the integer Unicode code point of a given character.

In [1]:
characs = ['ლ', 'პ', '🌞', '🔥', '༂', '༅']  # Georgian letters, emoji, Tibetan marks
for char in characs:
    print(f"Character: {char} -> Unicode code: {ord(char)}")

Character: ლ -> Unicode code: 4314
Character: პ -> Unicode code: 4318
Character: 🌞 -> Unicode code: 127774
Character: 🔥 -> Unicode code: 128293
Character: ༂ -> Unicode code: 3842
Character: ༅ -> Unicode code: 3845


computers ultimately read and write bytes, not abstract code points like the ones shown above. An *encoding* maps a sequence of code points to a sequence of bytes (encoding, a form of serialization) and back (decoding). The Unicode Standard defines three encoding forms:

- UTF-8 -> 8-bit (1-byte) code units, variable length (1-4 bytes per code point); the default on the web, in most APIs, and on Unix.
- UTF-16 -> 16-bit (2-byte) code units, variable length (2 or 4 bytes per code point).
- UTF-32 -> 32-bit (4-byte) code units, fixed width (always 4 bytes per code point).
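
A quick way to see the three encoding forms side by side is to count the bytes a single character needs in each. This sketch uses the `-le` (little-endian) codec names so Python does not prepend a byte-order mark; the choice of characters is arbitrary.

```python
# "-le" variants are used so Python does not prepend a byte-order mark (BOM)
ENCODINGS = ("utf-8", "utf-16-le", "utf-32-le")

def encoded_sizes(ch: str) -> tuple:
    """Bytes needed for one character in each encoding form."""
    return tuple(len(ch.encode(enc)) for enc in ENCODINGS)

# "\u00e9" is the precomposed é
for ch in ["A", "\u00e9", "ლ", "世", "🔥"]:
    u8, u16, u32 = encoded_sizes(ch)
    print(f"U+{ord(ch):04X} {ch}: utf-8={u8}  utf-16={u16}  utf-32={u32}")
```

🔥 needs 4 bytes even in UTF-16 because it lies outside the Basic Multilingual Plane and is stored as a surrogate pair, which is what makes UTF-16 variable-length.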

almost every webpage is transmitted as UTF-8, mostly for the following reasons:
- it is the only one of the three that is backward compatible with ASCII (an older 7-bit encoding still used by a lot of legacy tooling): any ASCII-encoded text file can be decoded as UTF-8 and gives exactly the same result.
- it is space efficient, at least for Latin-script corpora. English text, for instance, stays at 1 byte per character, whereas UTF-16 and UTF-32 double or quadruple that (see below).
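
The ASCII-compatibility claim can be checked directly. The sketch below also shows a related property: the bytes of a multi-byte UTF-8 character are all 0x80 or above, so they can never be mistaken for ASCII characters.

```python
text = "Hello, world!"

# every ASCII byte sequence is already valid UTF-8 with the same meaning
assert text.encode("ascii") == text.encode("utf-8")
assert text.encode("ascii").decode("utf-8") == text

# UTF-16 is not ASCII-compatible: each ASCII character gains a zero byte
print(text.encode("utf-16-le")[:6])  # b'H\x00e\x00l\x00'

# bytes inside multi-byte UTF-8 characters never fall in the ASCII range
multi = "ლ🔥".encode("utf-8")
print(all(b >= 0x80 for b in multi))  # True
```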

> a note on something i was personally confused with:
>
> - UTF-8 is 'variable-length' because different characters need 1-4 bytes to encode. but **each individual byte** can still take only one of 256 values (0 to 255, i.e. at most $2^8 - 1$), and those 256 byte values are everything a byte-level tokenizer treats as its vocabulary. do not confuse characters with byte tokens: there are only 256 distinct byte tokens, but combining them lets UTF-8 spell far more than 256 characters.
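
The 256 byte values are not all interchangeable: in UTF-8, the leading bits of each byte say whether it is a standalone ASCII character, the start of a multi-byte character, or a continuation. A simplified classifier, assuming valid UTF-8 input (it does not flag the byte values that never appear in valid UTF-8); the function name is hypothetical:

```python
def utf8_byte_role(b: int) -> str:
    """Classify one byte value (0-255) by its bit prefix, assuming valid UTF-8."""
    if b < 0x80:
        return "ascii"         # 0xxxxxxx: a whole character on its own
    if b < 0xC0:
        return "continuation"  # 10xxxxxx: inside a multi-byte character
    if b < 0xE0:
        return "lead-2"        # 110xxxxx: starts a 2-byte character
    if b < 0xF0:
        return "lead-3"        # 1110xxxx: starts a 3-byte character
    return "lead-4"            # 11110xxx: starts a 4-byte character

for ch in ["a", "ლ", "🔥"]:
    encoded = ch.encode("utf-8")
    print(ch, list(encoded), [utf8_byte_role(b) for b in encoded])
```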

In [2]:
from sys import getsizeof

samples = {
    "ascii": "Hello, world!",
    "mutlilang": "こんにちは世界🌍", # hello world in japanese + emoji
    "emojis": "🔥🌞💧🌱" # emojis
}

print("{name:10} | utf‑8 | utf‑16‑le | utf‑32‑le | Python str (CPython 3.12) sizeof")
print("---------")
for name, s in samples.items():
    utf8 = len(s.encode("utf-8"))
    u16 = len(s.encode("utf-16-le"))
    u32 = len(s.encode("utf-32-le"))
    pyobj = getsizeof(s)
    print(f"{name:10} | {utf8:5d} | {u16:10d} | {u32:10d} | {pyobj:27d}")

{name:10} | utf‑8 | utf‑16‑le | utf‑32‑le | Python str (CPython 3.12) sizeof
---------
ascii      |    13 |         26 |         52 |                          62
mutlilang  |    25 |         18 |         32 |                         108
emojis     |    16 |         16 |         16 |                          92


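Notice that, for ASCII-heavy text (Latin-script corpora), UTF-8 is half the size of UTF-16 and a quarter of UTF-32. the advantage is not universal, though: for the Japanese sample, UTF-16 (18 bytes) beats UTF-8 (25 bytes), because most of those characters need 3 bytes in UTF-8 but only 2 in UTF-16. let's show the proportion of zero-valued bytes and preview the first 32 bytes of each encoding: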

In [3]:
def hex_preview(blob:bytes, limit=32):
    head = blob[:limit]
    return " ".join(f"{b:02x}" for b in head) + (" ..." if len(blob) > limit else "")

for name, s in samples.items():
    print(name.upper())
    for enc in ("utf-8", "utf-16-le", "utf-32-le"):
        blob = s.encode(enc)
        zeros = blob.count(0)
        print(f"{enc:10} | {len(blob):2d} bytes | {zeros:2d} zero bytes | {100*zeros/len(blob):5.1f}% zeros")
        print("  ", hex_preview(blob))
    print()



ASCII
utf-8      | 13 bytes |  0 zero bytes |   0.0% zeros
   48 65 6c 6c 6f 2c 20 77 6f 72 6c 64 21
utf-16-le  | 26 bytes | 13 zero bytes |  50.0% zeros
   48 00 65 00 6c 00 6c 00 6f 00 2c 00 20 00 77 00 6f 00 72 00 6c 00 64 00 21 00
utf-32-le  | 52 bytes | 39 zero bytes |  75.0% zeros
   48 00 00 00 65 00 00 00 6c 00 00 00 6c 00 00 00 6f 00 00 00 2c 00 00 00 20 00 00 00 77 00 00 00 ...

MUTLILANG
utf-8      | 25 bytes |  0 zero bytes |   0.0% zeros
   e3 81 93 e3 82 93 e3 81 ab e3 81 a1 e3 81 af e4 b8 96 e7 95 8c f0 9f 8c 8d
utf-16-le  | 18 bytes |  0 zero bytes |   0.0% zeros
   53 30 93 30 6b 30 61 30 6f 30 16 4e 4c 75 3c d8 0d df
utf-32-le  | 32 bytes | 15 zero bytes |  46.9% zeros
   53 30 00 00 93 30 00 00 6b 30 00 00 61 30 00 00 6f 30 00 00 16 4e 00 00 4c 75 00 00 0d f3 01 00

EMOJIS
utf-8      | 16 bytes |  0 zero bytes |   0.0% zeros
   f0 9f 94 a5 f0 9f 8c 9e f0 9f 92 a7 f0 9f 8c b1
utf-16-le  | 16 bytes |  0 zero bytes |   0.0% zeros
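   3d d8 25 dd 3c d8 1e df 3d d8 a7 dc 3c d8 31 df
utf-32-le  | 16 bytes |  4 zero bytes |  25.0% zeros
   25 f5 01 00 1e f3 01 00 a7 f4 01 00 31 f3 01 00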

note that each ASCII character in UTF-16-LE carries one zero byte (the unused high byte of its 16-bit code unit), while UTF-32-LE carries three. that is 50% and 75% 'wasted' space respectively, for corpora dominated by Latin-script text. even for emoji-heavy text, where every character already needs 4 bytes in UTF-8 and all three encodings come out the same size, a quarter of the UTF-32 bytes are still zeros.

### using raw utf-8 bytes as tokens looks tempting...

Using raw UTF-8 bytes as tokens is an option. at the end of the day, it gives us a tiny embedding table (only 256 possible byte values) while removing the out-of-vocabulary problem entirely. additionally, encoding and decoding any given text is trivial, and already implemented for us:

In [4]:
text = "to be or not to be"
utf_8_text_bytes = b"to be or not to be"
print(f"Byte literal: {utf_8_text_bytes}")
print(f"Class of byte literal: {type(utf_8_text_bytes)}")
print(f"Elements of byte literal:\n {list(utf_8_text_bytes)}")
og_text = utf_8_text_bytes.decode("utf-8")
print(f"Original text: {og_text}")
print(f"Class of original text: {type(og_text)}")
print(f"Elements of original text:\n {list(og_text)}")

Byte literal: b'to be or not to be'
Class of byte literal: <class 'bytes'>
Elements of byte literal:
 [116, 111, 32, 98, 101, 32, 111, 114, 32, 110, 111, 116, 32, 116, 111, 32, 98, 101]
Original text: to be or not to be
Class of original text: <class 'str'>
Elements of original text:
 ['t', 'o', ' ', 'b', 'e', ' ', 'o', 'r', ' ', 'n', 'o', 't', ' ', 't', 'o', ' ', 'b', 'e']


This would then be a byte-level tokenizer, simply breaking the UTF-8 stream at every byte boundary. characters outside the Basic Multilingual Plane, which take four bytes, just become four consecutive byte tokens; the size of the vocabulary does not change. as the papers introduced above explain, this creates three problems:

- **longer sequences**: every character needs at least 1 byte, and every non-ASCII code point expands to 2-4 bytes, which pushes us towards very long token sequences. the cost of self-attention, the core operation of a transformer (the architecture behind large language models), grows as $L^2$ in sequence length $L$. the longer the inputs, the more attention work, and the less text fits into a fixed-size context window.
- **tokens do not line up with meaning**: a single emoji, such as ⚽, reaches the model as several integers (3 for ⚽, see below) that mean nothing individually. since the network has to learn how bytes compose into code points (and those into words and longer grammatical structures) before it can learn anything about semantics, this wastes some capacity and training time.

In [5]:
list("⚽".encode("utf-8"))

[226, 154, 189]

- **the very slim vocab hurts compression**: a byte-level tokenizer spends one token per byte, so every word costs as many tokens as it has bytes. while the embedding table shrinks, the compute spent on activations, positional embeddings, and attention grows far faster. in the example below, for instance, a Japanese text explodes from 8 code points to 25 byte tokens, and emojis grow by an even larger factor (4 to 16).

In [6]:
def describe(name, s):
    b = s.encode("utf-8")
    b_tokens = list(b)
    cp_tokens = list(s)

    print(f"{name:10} | bytes: {len(b_tokens):2d} tokens | code-points: {len(cp_tokens):2d} tokens")
    print(f" unique byte-tokens used: {len(set(b_tokens)):3d} / 256")
    print(f"preview bytes: {b_tokens[:16]}")
    print()

for n, s in samples.items():
    describe(n, s)

ascii      | bytes: 13 tokens | code-points: 13 tokens
 unique byte-tokens used:  10 / 256
preview bytes: [72, 101, 108, 108, 111, 44, 32, 119, 111, 114, 108, 100, 33]

mutlilang  | bytes: 25 tokens | code-points:  8 tokens
 unique byte-tokens used:  16 / 256
preview bytes: [227, 129, 147, 227, 130, 147, 227, 129, 171, 227, 129, 161, 227, 129, 175, 228]

emojis     | bytes: 16 tokens | code-points:  4 tokens
 unique byte-tokens used:   9 / 256
preview bytes: [240, 159, 148, 165, 240, 159, 140, 158, 240, 159, 146, 167, 240, 159, 140, 177]



let's make a very rough estimate of the extra attention compute caused by this explosion:

In [7]:
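def transformer_flops(n_layers, dim, seq_len, factor=2):
    """Very rough: counts only the seq_len^2 attention term and ignores the linear layers."""
    return n_layers * factor * (seq_len ** 2) * dim

L  = 12           # layers
H  = 768          # hidden size
char_len = 8      # code points in the Japanese sample
byte_len = 25     # UTF-8 bytes in the same sample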

print("Char-level  :", f"{transformer_flops(L, H, char_len):,.2f} floating point operations needed.")
print("Byte-level  :", f"{transformer_flops(L, H, byte_len):,.2f} floating point operations needed.")


Char-level  : 1,179,648.00 floating point operations needed.
Byte-level  : 11,520,000.00 floating point operations needed.
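
`transformer_flops` only keeps the quadratic attention term. As a hedged refinement, the sketch below assumes the common per-layer forward estimate of about $24Ld^2$ FLOPs for the projections and a 4x-wide MLP, plus $4L^2d$ for the attention scores and the weighted sum. At these tiny lengths the linear term dominates, so the byte-level sequence costs roughly 25/8, about 3x more, rather than about 10x; the quadratic term only takes over when $L > 6d$.

```python
def layer_forward_flops(seq_len: int, dim: int) -> int:
    """Rough forward FLOPs for one transformer layer (assumes a 4x-wide MLP)."""
    linear = 24 * seq_len * dim ** 2     # QKV + output projections + MLP
    attention = 4 * seq_len ** 2 * dim   # QK^T scores + weighted sum of V
    return linear + attention

n_layers, dim = 12, 768
char_len, byte_len = 8, 25

char_cost = n_layers * layer_forward_flops(char_len, dim)
byte_cost = n_layers * layer_forward_flops(byte_len, dim)
print(f"byte/char cost ratio: {byte_cost / char_cost:.2f}")  # 3.14, close to 25/8
print(f"attention dominates only once seq_len > 6 * dim = {6 * dim}")
```

Shorter sequences still win either way; the gap just grows linearly until contexts get long, and quadratically after that.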


#### utf-8 aware $\ne$ byte-bound

We want to start from, and keep, UTF-8 as a base because it gives us three very important guarantees:

- **coverage**: almost all web text, and any future Unicode character, already has a UTF-8 spelling, so the tokenizer never goes out of date (unless the Unicode standard itself is replaced).
- **lossless round-trips**: text -> tokens -> text is always possible and exact.
- **streaming**: UTF-8 can be processed left-to-right without backtracking, which suits data pipelines and distributed, sharded datasets.

therefore, on top of the same 256-value byte alphabet, we want some clever operation that shrinks average sequence length, caps the vocabulary at a manageable size, remains fully reversible, and stays future-proof.
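
The streaming guarantee is what makes sharding safe: UTF-8 continuation bytes always have the bit pattern `10xxxxxx`, so from any byte offset you can skip forward to the next character boundary without looking back. A small sketch; `next_char_boundary` is a hypothetical helper name:

```python
def next_char_boundary(data: bytes, pos: int) -> int:
    """Move forward from an arbitrary byte offset to the start of the next character."""
    while pos < len(data) and (data[pos] & 0xC0) == 0x80:  # skip 10xxxxxx bytes
        pos += 1
    return pos

data = "こんにちは世界🌍".encode("utf-8")

# pretend a data pipeline split the byte stream at an arbitrary offset (here 4)
cut = next_char_boundary(data, 4)
left, right = data[:cut], data[cut:]
print(left.decode("utf-8"), "|", right.decode("utf-8"))  # こん | にちは世界🌍
```

Both halves decode cleanly on their own, so shards can be cut anywhere and repaired locally.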

### actually implementing byte-pair encoding

let's use Morse code to see byte-pair encoding in action. the Morse alphabet consists of only two symbols (plus the space separator), each of which is a single byte in UTF-8.

In [8]:
morse_alphabet = [".", "-"]
for token in morse_alphabet:
    print(f"Byte-encodings of morse code ({token}): {list(token.encode('utf-8'))}")

Byte-encodings of morse code (.): [46]
Byte-encodings of morse code (-): [45]


suppose for instance we wanted to encode ['hello hello' in Morse](https://morsecode.world/international/translator.html):

``` bash
.... . .-.. .-.. --- .... . .-.. .-.. ---

# hello hello in morse
```





In [9]:
hello_hello = ".... . .-.. .-.. --- .... . .-.. .-.. ---"

# encode to raw bytes
byte_tokens = list(hello_hello.encode("utf-8"))
unique_byte_tokens = set(byte_tokens)

# map each byte back to its printable char
unique_chars = [bytes([b]).decode("utf-8") for b in unique_byte_tokens]


print(f"Raw byte list:\n{byte_tokens}\n")
print(f"Sequence length: {len(byte_tokens)}")
print(f"Unique byte tokens: {len(unique_byte_tokens)} → {unique_byte_tokens}")
print(f"Unique char tokens: {len(unique_chars)}   → {unique_chars}")

Raw byte list:
[46, 46, 46, 46, 32, 46, 32, 46, 45, 46, 46, 32, 46, 45, 46, 46, 32, 45, 45, 45, 32, 46, 46, 46, 46, 32, 46, 32, 46, 45, 46, 46, 32, 46, 45, 46, 46, 32, 45, 45, 45]

Sequence length: 41
Unique byte tokens: 3 → {32, 45, 46}
Unique char tokens: 3   → [' ', '-', '.']


so the Morse-encoded 'hello hello' uses only three distinct byte tokens. that is a very small vocab: the embedding table is tiny, but sequence length explodes (41 symbols for just two words). this is exactly the trade-off we described earlier.

to fix this, BPE iteratively finds the most common adjacent pair of symbols, fuses it into a single new symbol, and replaces every occurrence. here we count the adjacent pairs and pick the most common one:

In [10]:
from collections import Counter

counts = Counter(zip(hello_hello, hello_hello[1:]))
print("Counts of adjacent pairs:")
for c in counts:
    print(f"\n{c}, {counts[c]}")

(a, b), freq = counts.most_common(1)[0]
print("Most common pair:")
if freq > 1:  # BPE stops once every pair is unique
    print(f"'{a}{b}' with freq: {freq}")

Counts of adjacent pairs:

('.', '.'), 10

('.', ' '), 8

(' ', '.'), 7

('.', '-'), 4

('-', '.'), 4

(' ', '-'), 2

('-', '-'), 4

('-', ' '), 1
Most common pair:
'..' with freq: 10


In [11]:
# helper function
def most_common_pair(seq):
    counts = Counter(zip(seq, seq[1:]))
    if not counts:  # sequences shorter than 2 have no pairs
        return None
    (a, b), freq = counts.most_common(1)[0]
    return (a, b, freq) if freq > 1 else None  # stop when every pair is unique

most_common_pair(list(hello_hello))

('.', '.', 10)

then, we mint a new symbol for the most common pair and replace every occurrence of it. here we mint a symbol called 'A' that replaces every '..', so the starting symbols `....` become `AA`.

In [12]:
def fuse_once(seq, pair, new_sym):
    a, b = pair
    out, i = [], 0
    while i < len(seq):
        if i + 1 < len(seq) and seq[i]==a and seq[i+1]==b:
            out.append(new_sym); i += 2
        else:
            out.append(seq[i]);  i += 1
    return out

vocabulary = {'.', '-', ' '}  # the space is a base symbol too
rules = []

pair_found = most_common_pair(list(hello_hello))
a, b, freq = pair_found
rules.append(f"A = {a}{b}")
vocabulary.add('A')

new_hello_hello = fuse_once(hello_hello, (a,b), "A")
print(f"New String: {new_hello_hello}")
print(f"Rule: {rules}")

New String: ['A', 'A', ' ', '.', ' ', '.', '-', 'A', ' ', '.', '-', 'A', ' ', '-', '-', '-', ' ', 'A', 'A', ' ', '.', ' ', '.', '-', 'A', ' ', '.', '-', 'A', ' ', '-', '-', '-']
Rule: ['A = ..']


the symbol `A` then becomes part of the vocabulary, and the same process runs iteratively until the vocabulary reaches a predetermined size, or until no adjacent pair occurs more than once. the next iteration would look like this:

In [13]:
pair_found = most_common_pair(list(new_hello_hello))
a, b, freq = pair_found
rules.append(f"B = {a}{b}")
vocabulary.add('B')

new_hello_hello = fuse_once(new_hello_hello, (a,b), "B")
print(f"New String: {new_hello_hello}")
print(f"Rule: {rules}")

New String: ['A', 'B', '.', ' ', '.', '-', 'B', '.', '-', 'B', '-', '-', '-', ' ', 'A', 'B', '.', ' ', '.', '-', 'B', '.', '-', 'B', '-', '-', '-']
Rule: ['A = ..', 'B = A ']


so the pair `A ` (an `A` followed by a space, i.e. `.. `) gets replaced by a new symbol `B`, which is then added to our vocabulary. running the algorithm until no pair repeats gives the following:

In [14]:
import string

vocabulary = {'.', '-', ' '}
rules    = []
fresh    = iter(string.ascii_uppercase)
step     = 1

print("Start :", "".join(hello_hello))
while True:
    found = most_common_pair(hello_hello)
    if not found:
        break
    a, b, freq = found
    new_sym    = next(fresh)          # mint A, B, C, ...
    rules.append(f"{new_sym} = {a}{b}")
    vocabulary.add(new_sym)
    hello_hello = fuse_once(hello_hello, (a, b), new_sym)

    print(f"Step {step:>2} : replace '{a}{b}' → '{new_sym}' "
          f"(appeared {freq}×)  →  {''.join(hello_hello)}")
    step += 1

print("\nCompressed sequence :", "".join(hello_hello))
print("Merge rules (oldest → newest):")
for r in rules:
    print(" ", r)

Start : .... . .-.. .-.. --- .... . .-.. .-.. ---
Step  1 : replace '..' → 'A' (appeared 10×)  →  AA . .-A .-A --- AA . .-A .-A ---
Step  2 : replace 'A ' → 'B' (appeared 6×)  →  AB. .-B.-B--- AB. .-B.-B---
Step  3 : replace 'B.' → 'C' (appeared 4×)  →  AC .-C-B--- AC .-C-B---
Step  4 : replace '--' → 'D' (appeared 4×)  →  AC .-C-BD- AC .-C-BD-
Step  5 : replace 'AC' → 'E' (appeared 2×)  →  E .-C-BD- E .-C-BD-
Step  6 : replace 'E ' → 'F' (appeared 2×)  →  F.-C-BD- F.-C-BD-
Step  7 : replace 'F.' → 'G' (appeared 2×)  →  G-C-BD- G-C-BD-
Step  8 : replace 'G-' → 'H' (appeared 2×)  →  HC-BD- HC-BD-
Step  9 : replace 'HC' → 'I' (appeared 2×)  →  I-BD- I-BD-
Step 10 : replace 'I-' → 'J' (appeared 2×)  →  JBD- JBD-
Step 11 : replace 'JB' → 'K' (appeared 2×)  →  KD- KD-
Step 12 : replace 'KD' → 'L' (appeared 2×)  →  L- L-
Step 13 : replace 'L-' → 'M' (appeared 2×)  →  M M

Compressed sequence : M M
Merge rules (oldest → newest):
  A = ..
  B = A 
  C = B.
  D = --
  E = AC
  F = E 
  G = F.
  H = G-
  I = HC
  J = I-
  K = JB
  L = KD
  M = L-

to be clear, the vocabulary keeps every symbol minted along the way. the final sequence `M M` is only three tokens (`M`, a space, `M`) with just two distinct types, and it collapses this far only because the input repeats the same word twice ('hello hello'). the embedding table, however, still needs one entry for every original and minted symbol:

In [15]:
print("Final compressed sequence:", ''.join(hello_hello))
print("Vocabulary size:", len(vocabulary))
print("Vocabulary symbols:", sorted(vocabulary))


Final compressed sequence: M M
Vocabulary size: 16
Vocabulary symbols: [' ', '-', '.', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M']


More importantly, during BPE training we decide when to stop merging. in the example above, starting from the original vocabulary `['.', '-', ' ']`, we could ask for a final vocabulary of 10 symbols and stop after the 7th step, leaving `['.', '-', ' ', 'A', 'B', 'C', 'D', 'E', 'F', 'G']`. the earlier we stop, the smaller the vocabulary and the longer the token sequences. conversely, if we merge until there is nothing left to merge, the vocabulary grows, but the sequences become ultra-short.
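
Training only produces the ordered list of rules. To tokenize new text, one straightforward approach (assumed here, since it mirrors how the rules were learned) is to replay the merges in the order they were learned; decoding expands each minted symbol back into its pair. The sketch re-types the first seven rules from the run above by hand, i.e. the vocabulary-of-10 case; the helper names are hypothetical.

```python
def apply_merges(seq, merges):
    """Encode a sequence by replaying learned merges in the order they were learned."""
    seq = list(seq)
    for (a, b), new_sym in merges:
        out, i = [], 0
        while i < len(seq):
            if i + 1 < len(seq) and seq[i] == a and seq[i + 1] == b:
                out.append(new_sym)
                i += 2
            else:
                out.append(seq[i])
                i += 1
        seq = out
    return seq

def expand(sym, table):
    """Recursively turn a minted symbol back into the characters it stands for."""
    if sym in table:
        a, b = table[sym]
        return expand(a, table) + expand(b, table)
    return sym

# first seven rules from the training run above (vocabulary of 10 symbols)
morse_merges = [((".", "."), "A"), (("A", " "), "B"), (("B", "."), "C"),
                (("-", "-"), "D"), (("A", "C"), "E"), (("E", " "), "F"),
                (("F", "."), "G")]
morse_table = {new_sym: pair for pair, new_sym in morse_merges}

hello = ".... . .-.. .-.. --- .... . .-.. .-.. ---"
hi = ".... .."  # "hi", never seen during training
for morse in (hello, hi):
    tokens = apply_merges(morse, morse_merges)
    restored = "".join(expand(t, morse_table) for t in tokens)
    print(f"{''.join(tokens):16} {len(tokens):2d} tokens  round-trip ok: {restored == morse}")
```

The training string comes out as `G-C-BD- G-C-BD-` (15 tokens, matching step 7 of the run above), and `.... ..` becomes `ABA`; both decode back exactly.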

It is also important to note that the example so far has operated on characters directly (really, it has been character-pair encoding...). we get away with it because every character in our tiny corpus (`hello hello` in Morse) is a single byte (see above). remember, however, that UTF-8 encodes a Unicode character into up to 4 bytes, and *byte*-pair encoding never sees characters: it sees raw byte values (0-255) only.
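
A small illustration of why this matters: on Japanese text, the most frequent *byte* pair is the first two bytes shared by several different characters, so a byte-level merge can mint a token that is only part of a character and is not valid UTF-8 on its own. The sample string is arbitrary:

```python
from collections import Counter

text = "こんにちは"                      # five characters, 15 UTF-8 bytes
byte_seq = list(text.encode("utf-8"))

(a, b), freq = Counter(zip(byte_seq, byte_seq[1:])).most_common(1)[0]
print(f"most common byte pair: ({a:#x}, {b:#x}) x{freq}")  # (0xe3, 0x81) x4

merged = bytes([a, b])
try:
    merged.decode("utf-8")
except UnicodeDecodeError:
    print("the new token is only part of a character and is not valid UTF-8 on its own")
```

This is harmless for training (the full byte sequence still decodes), but it means individual tokens cannot always be printed as text.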

### byte-pair encoding for a larger corpus

i decided to use the [Gutenberg Poetry Corpus](https://github.com/aparrish/gutenberg-poetry-corpus) by Allison Parrish. it contains more than 3 million lines of poetry taken from publicly available books on [Project Gutenberg](https://gutenberg.org/). the `line` column holds a single line of poetry, while `gutenberg_id` is the ID of the Project Gutenberg book the line comes from.

In [16]:
from datasets import load_dataset
ds = load_dataset("biglam/gutenberg-poetry-corpus", split="train", streaming=True)

# streaming=True iterates over the dataset without downloading it first; rows are fetched as we iterate.

In [17]:
print(ds)

IterableDataset({
    features: ['line', 'gutenberg_id'],
    num_shards: 1
})


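let's take 100k lines as a start. note that a streaming shuffle only mixes rows inside a fixed-size buffer (here 100k rows), so the sample is only approximately random over the whole corpus: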

In [18]:
import itertools
sample = list(itertools.islice(ds.shuffle(buffer_size=100_000, seed=42), 100_000))
corpus = [list(row["line"].encode("utf-8")) for row in sample] # raw bytes

In [19]:
corpus[:5]  # 5 shuffled poetry lines as raw byte values

[[83, 111, 32, 116, 104, 114, 101, 97, 116, 110, 39, 100, 32, 104, 101, 101, 44, 32, 98, 117, 116, 32, 83, 65, 84, 65, 78, 32, 116, 111, 32, 110, 111, 32, 116, 104, 114, 101, 97, 116, 115],
 [83, 111, 32, 97, 115, 32, 116, 104, 101, 105, 32, 98, 101, 32, 111, 102, 32, 100, 111, 117, 98, 108, 101, 32, 101, 110, 116, 101, 110, 116, 101, 58],
 [79, 110, 101, 32, 114, 111, 119, 32, 111, 102, 32, 114, 101, 100, 32, 110, 111, 115, 116, 114, 105, 108, 115, 32, 116, 104, 97, 116, 32, 115, 99, 101, 110, 116, 32, 98, 97, 116, 116, 108, 101, 45, 102, 117, 109, 101, 115, 46],
 [84, 111, 32, 116, 121, 114, 97, 110, 116, 115, 32, 111, 116, 104, 101, 114, 115, 32, 104, 97, 118, 101, 32, 116, 104, 101, 105, 114, 32, ...
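
To double-check that these lists are just the lines' UTF-8 bytes, we can decode one back. The list below is the first row printed above, copied by hand:

```python
first_line_bytes = [83, 111, 32, 116, 104, 114, 101, 97, 116, 110, 39, 100, 32, 104, 101, 101, 44,
                    32, 98, 117, 116, 32, 83, 65, 84, 65, 78, 32, 116, 111, 32, 110, 111, 32, 116,
                    104, 114, 101, 97, 116, 115]

# bytes() packs the integers 0-255 back into a byte string, then we decode it as UTF-8
print(bytes(first_line_bytes).decode("utf-8"))  # So threatn'd hee, but SATAN to no threats
```

Inside the notebook, the equivalent check is `bytes(corpus[0]).decode("utf-8") == sample[0]["line"]`.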

In [20]:
from itertools import chain

corpus_lines_str = [row["line"] for row in sample]
total_chars = sum(len(line) for line in corpus_lines_str)
unique_chars = {ch for line in corpus_lines_str for ch in line}

# chain takes a series of iterables and returns one iterable

byte_seq = list(chain.from_iterable(corpus))
lines_sampled = len(corpus)
total_byte_number = len(byte_seq)
unique_byte_number = len(set(byte_seq))

print(f"Lines sampled: {lines_sampled}")
print(f"Total **bytes**: {total_byte_number:,}")
print(f"Total unique bytes: {unique_byte_number}")
print()
print(f"Total **characters** : {total_chars:,}")
print(f"Unique characters num : {len(unique_chars):,}")
print(f"Avg bytes per char:{total_byte_number/total_chars:4.2f}")

Lines sampled: 100000
Total **bytes**: 3,836,123
Total unique bytes: 101

Total **characters** : 3,836,075
Unique characters num : 99
Avg bytes per char:1.00


So, the 100,000 sampled poetry lines contain 3,836,123 bytes in total, but only 101 distinct byte values. that means only 101/256 $\approx$ 40% of the possible byte values ever appear. but as we already know, even though the embedding table is tiny, the numbers also exhibit the massive sequence redundancy we are now familiar with.

An average of $\approx$ 1 byte per character indicates that the corpus is almost pure ASCII (the byte count exceeds the character count by only 48), so multi-byte UTF-8 characters are not to blame for the length. what inflates the sequence is not UTF-8 overhead at all, but simply the decision to use single-byte tokens (3,836,123 tokens drawn from only 101 distinct values!!!). we need to dramatically reduce compute. as mentioned earlier, the cost of a larger vocabulary grows linearly (one more embedding row per token), while the cost of attention over longer sequences grows quadratically.

let's spend a little more memory on the vocabulary, and much less on sequence length, by running our byte-pair algorithm on the corpus:

In [24]:
from pathlib import Path

CACHE_DIR = Path("bpe_cache")
CACHE_DIR.mkdir(exist_ok=True)
BYTE_VOCAB_SIZE = 256  # one id per possible byte value; merged tokens start at 256

In [26]:
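import pickle, json
from collections import Counter
from itertools import chain

import numpy as np

def most_common_pair(seq):
    counts = Counter(zip(seq, seq[1:]))
    if not counts:  # sequences shorter than 2 have no pairs
        return None
    (a, b), freq = counts.most_common(1)[0]
    return (a, b, freq) if freq > 1 else None  # stop when every pair is unique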

def fuse_once(seq, pair, new_sym):
    a, b = pair
    out, i = [], 0
    while i < len(seq):
        if i + 1 < len(seq) and seq[i]==a and seq[i+1]==b:
            out.append(new_sym); i += 2
        else:
            out.append(seq[i]);  i += 1
    return out

def train_bpe(corpus, max_merges: int, record_every:int = 5):
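    # accept either one flat list of byte values or a list of per-line byte lists
    if isinstance(corpus, list) and all(isinstance(x, int) for x in corpus):
        seq = corpus[:]
    else:
        # lines are concatenated without a separator, so merges can span line ends
        seq = list(chain.from_iterable(corpus))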

    vocab = set(range(BYTE_VOCAB_SIZE))
    merges, history = [], []
    next_id = BYTE_VOCAB_SIZE # so our first added merge will be 256

    for i in range(max_merges):
        pair = most_common_pair(seq)
        if not pair:
            break
        a, b, _ = pair # dont really need the freq
        new_sym = next_id; next_id += 1
        merges.append((a, b))
        vocab.add(new_sym)
        seq = fuse_once(seq, (a, b), new_sym)

        if i % record_every == 0 or i == max_merges-1:
            history.append({
                "merge_step": i+1,
                "seq_len": len(seq),
                "vocab_size": len(vocab)
            })

    return merges, seq, vocab, history

def run_and_cache(name, corpus, max_merges:int, **kwargs):
    """
    this is just for me to run the bpe 'trainer' once, cache it to disk and return everything.
    """
    stem = f"{name}_{max_merges}"
    files = {
        "merges": CACHE_DIR/ f"{stem}_merges.pkl",
        "compressed": CACHE_DIR/ f"{stem}_seq.npy",
        "vocab": CACHE_DIR/ f"{stem}_vocab.json",
        "history": CACHE_DIR/ f"{stem}_history.json",
    }

    # if everything is already on disk
    if all(p.exists() for p in files.values()):
        with open(files["merges"], "rb") as f: merges = pickle.load(f)
        seq = np.load(files["compressed"])
        with open(files["vocab"]) as f: vocab = set(json.load(f))
        with open(files["history"]) as f: history = json.load(f)
        return merges, seq.tolist(), vocab, history  # plain Python ints, like train_bpe returns

    merges, seq, vocab, history = train_bpe(corpus, max_merges, **kwargs)

    # dump everything
    with open(files["merges"], "wb") as f: pickle.dump(merges, f, protocol=5)
    np.save(files["compressed"], np.array(seq, dtype=np.uint32))
    with open(files["vocab"],   "w") as f: json.dump(sorted(vocab), f)
    with open(files["history"], "w") as f: json.dump(history, f, indent=2)

    return merges, seq, vocab, history
    

In [27]:
max_merges = 500  # avoid shadowing the built-in max()
experiments = {}

merges, seq, vocab, history = run_and_cache("poetry", corpus, max_merges, record_every=5)
experiments[max_merges] = {"seq": seq, "vocab": vocab, "history": history}
print(f"{max_merges:4d} merges ➜ {len(seq):,} tokens, vocab {len(vocab)}")


 500 merges ➜ 1,692,762 tokens, vocab 756


In [29]:
merges, seq, vocab, history = run_and_cache("poetry", corpus, max_merges, record_every=5)
print(f"amount of merges learned : {len(merges):,}")
print(f"amount of compressed tokens : {len(seq):,}")
print(f"size of final vocabulary: {len(vocab)}")
print(f"final vocabulary:\n {vocab}")

amount of merges learned : 500
amount of compressed tokens : 1,692,762
size of final vocabulary: 756
final vocabulary:
 {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ..., 753, 754, 755}

Running our BPE algorithm for 500 merges yields a final vocabulary of 756 ids (the 256 original byte values + 500 new ones). notice also the new sequence length, down from 3.8 million to around 1.7 million. but we need to be precise about what was compressed here and why we care.

- the 1.7 million is a token count, not a file size in bytes. before BPE, almost every character was a separate token (the corpus is nearly pure ASCII), and the model needed to process a sequence 3.8 million tokens long.
- after 500 merges, we introduced 500 new symbols; the sequence now contains only 1.7 million tokens.
- we removed roughly 56% of the tokens the model would otherwise see. but **we did not shrink the file on disk**: storing each token id as a 32-bit integer (4 bytes), the 'compressed' representation is actually bigger (about 1.7 million × 4 bytes, roughly 6.8 MB, versus 3.8 MB of raw bytes).
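
The arithmetic behind these bullets, using the counts printed above. The 4-byte width matches the `np.uint32` dtype that `run_and_cache` saves; since every id here is below 65,536, a 16-bit array would also fit, which would bring the file slightly under the raw byte count, but that is a side effect of the storage format rather than the goal of BPE.

```python
n_bytes = 3_836_123   # raw UTF-8 bytes in the 100k-line sample (printed above)
n_tokens = 1_692_762  # tokens after 500 merges (printed above)
vocab_size = 756

print(f"token reduction: {1 - n_tokens / n_bytes:.1%}")           # 55.9%
print(f"uint32 ids: {n_tokens * 4:,} bytes vs {n_bytes:,} raw")   # 6,771,048 bytes

# ids below 2**16 fit in 2 bytes each, so a uint16 array would also be valid here
assert vocab_size <= 2 ** 16
print(f"uint16 ids: {n_tokens * 2:,} bytes")                      # 3,385,524 bytes
```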

just remember that BPE is not a storage compression scheme, it is a computational one. it trades a modest increase in vocabulary size for a large drop in sequence length, and that trade is profitable because modern autoregressive models are bound by sequence length: every extra token costs compute at every layer.
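
The 'fully reversible' requirement from earlier is easy to check once each id is mapped back to its bytes. A minimal decoder sketch, assuming (as `train_bpe` does) that `merges[k]` minted id `256 + k`; the helper names and the toy merges are hypothetical:

```python
def build_token_bytes(merges, base=256):
    """Map every token id to the raw bytes it stands for.

    Assumes merges[k] == (a, b) created id base + k, in training order."""
    table = {i: bytes([i]) for i in range(base)}
    for k, (a, b) in enumerate(merges):
        table[base + k] = table[a] + table[b]  # both parts already exist in the table
    return table

def decode_ids(ids, table):
    """Concatenate each id's bytes and decode the result as UTF-8."""
    return b"".join(table[i] for i in ids).decode("utf-8")

# hypothetical toy merges: 256 = "th", 257 = "the"
toy_merges = [(116, 104), (256, 101)]
toy_table = build_token_bytes(toy_merges)
print(toy_table[257])                                 # b'the'
print(decode_ids([257, 32, 99, 97, 116], toy_table))  # the cat
```

With the notebook's own results, `decode_ids(seq, build_token_bytes(merges))` should reproduce `"".join(corpus_lines_str)`, since `train_bpe` concatenated the lines without separators.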

