In this notebook, we discuss different tokenization techniques that can be used when training a model:
- Word Tokenization
- Sentence Tokenization
- Byte Pair Encoding (BPE)
- SentencePiece
- Tiktoken

In [1]:
# Autosave the notebook every 300 seconds.
%autosave 300
# Automatically reload edited imported modules before each cell executes.
%reload_ext autoreload
%autoreload 2
# Disable jedi-based completion (presumably for faster tab completion — TODO confirm).
%config Completer.use_jedi = False

Autosaving every 300 seconds


In [2]:
import os

# Project root for relative paths such as data/ and tokenizers/.
# Set the ERAV2_ROOT environment variable to run this notebook on another
# machine; the default preserves the original (machine-specific) location.
PROJECT_ROOT = os.environ.get(
    "ERAV2_ROOT",
    "/mnt/batch/tasks/shared/LS_root/mounts/clusters/insights-model-run2/code/Users/soutrik.chowdhury/EraV2_Transformers",
)
os.chdir(PROJECT_ROOT)
print(os.getcwd())

/mnt/batch/tasks/shared/LS_root/mounts/clusters/insights-model-run2/code/Users/soutrik.chowdhury/EraV2_Transformers


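Minimal (byte-level) Byte Pair Encoding (BPE) tokenizer, adapted from Karpathy's minbpe: start from the 256 raw byte values and repeatedly merge the most frequent adjacent pair of tokens into a new token.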

In [11]:
import os
import urllib.request

url = 'https://raw.githubusercontent.com/karpathy/minbpe/master/tests/taylorswift.txt'
filename = 'data/taylor1.txt'
# urlretrieve does not create missing parent directories, so make sure
# data/ exists first (it may not on a fresh checkout).
os.makedirs(os.path.dirname(filename), exist_ok=True)
urllib.request.urlretrieve(url, filename)

('data/taylor1.txt', <http.client.HTTPMessage at 0x7efde44c4eb0>)

In [12]:
# Read the corpus as UTF-8 explicitly so the result does not depend on the
# platform's default locale encoding (train() re-encodes it to UTF-8 bytes).
with open('data/taylor1.txt', 'r', encoding='utf-8') as file:
    data = file.read()

In [13]:
# Corpus length in characters (len of a str, not the UTF-8 byte count).
print(len(data))

185561


In [15]:
import unicodedata

In [17]:
# Module-level tokenizer state; _build_vocab() reads merges and special_tokens.
merges = {} # (int, int) -> int
pattern = "" # str; written to / read from the model file by save() / load()
special_tokens = {}  # str -> int (special-token text -> token id)

In [18]:
def _build_vocab():
    """Position wise byte encoding for all 256 bytes"""
    # vocab is simply and deterministically derived from merges
    vocab = {idx: bytes([idx]) for idx in range(256)}
    for (p0, p1), idx in merges.items():
        vocab[idx] = vocab[p0] + vocab[p1]
    for special, idx in special_tokens.items():
        vocab[idx] = special.encode("utf-8")
    return vocab

In [19]:
# Initial vocabulary: merges and special_tokens are still empty at this point,
# so this contains only the 256 single-byte tokens.
vocab = _build_vocab()

In [21]:
def get_stats(ids, counts=None):
    """
    Count how often each pair of adjacent integers occurs in ``ids``.

    Example: [1, 2, 3, 1, 2] -> {(1, 2): 2, (2, 3): 1, (3, 1): 1}
    If ``counts`` is supplied it is updated in place and returned; otherwise
    a new dict is created.
    """
    if counts is None:
        counts = {}
    for pos in range(len(ids) - 1):
        key = (ids[pos], ids[pos + 1])
        counts[key] = counts.get(key, 0) + 1
    return counts


def merge(ids, pair, idx):
    """
    Return a new list in which every non-overlapping occurrence of ``pair``
    (scanning left to right) is replaced by the single token ``idx``.

    Example: ids=[1, 2, 3, 1, 2], pair=(1, 2), idx=4 -> [4, 3, 4]
    """
    last = len(ids) - 1
    out = []
    pos = 0
    while pos <= last:
        token = ids[pos]
        # A pair can only start here if there is a following element.
        if pos < last and token == pair[0] and ids[pos + 1] == pair[1]:
            out.append(idx)
            pos += 2
            continue
        out.append(token)
        pos += 1
    return out

In [26]:
def replace_control_characters(s: str) -> str:
    # Characters in Unicode general category "C" (control, format, surrogate,
    # private-use, unassigned — e.g. \n) would distort printed output, so
    # render them as \uXXXX escapes and keep everything else unchanged.
    # https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python/19016117#19016117
    # http://www.unicode.org/reports/tr44/#GC_Values_Table
    def _visible(ch: str) -> str:
        if unicodedata.category(ch).startswith("C"):
            return "\\u%04x" % ord(ch)
        return ch

    return "".join(_visible(ch) for ch in s)


def render_token(t: bytes) -> str:
    """Pretty-print a token's bytes for display.

    The bytes are decoded as UTF-8, with invalid sequences replaced by U+FFFD,
    and control characters are then escaped via replace_control_characters().
    """
    decoded = t.decode("utf-8", errors="replace")
    return replace_control_characters(decoded)

In [22]:
def train(text, vocab_size, verbose=False):
    """Learn byte-level BPE merges from ``text``.

    Starts from the 256 single-byte tokens and repeatedly merges the most
    frequent adjacent pair into a new token id (256, 257, ...) until the
    vocabulary reaches ``vocab_size`` or no adjacent pairs remain.

    Args:
        text: training corpus (str); it is UTF-8 encoded before training.
        vocab_size: target vocabulary size; must be >= 256.
        verbose: if True, print one line per learned merge.

    Returns:
        (vocab, merges): ``vocab`` maps token id -> bytes; ``merges`` maps
        (id, id) pairs -> merged id, in the order the merges were learned.

    Raises:
        AssertionError: if ``vocab_size`` < 256.
    """
    assert vocab_size >= 256
    # the number of merges is equal to the vocab size minus the number of bytes
    num_merges = vocab_size - 256

    ids = list(text.encode("utf-8"))  # list of integers in range 0..255

    merges = {}  # (int, int) -> int
    vocab = {idx: bytes([idx]) for idx in range(256)}  # int -> bytes
    for i in range(num_merges):
        # count up the number of times every consecutive pair appears
        stats = get_stats(ids)
        if not stats:
            # Fewer than two tokens remain, so nothing can be merged. Stop
            # early instead of letting max() raise on an empty dict.
            break
        # most frequent pair; ties go to the pair seen first
        pair = max(stats, key=stats.get)
        # mint a new token: assign it the next available id
        idx = 256 + i
        ids = merge(ids, pair, idx)
        merges[pair] = idx
        vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
        if verbose:
            print(
                f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences"
            )

    return vocab, merges

In [23]:
# Learn 512 - 256 = 256 merges on the corpus (verbose=True prints one line per
# merge). This rebinds the global vocab and merges used by later cells.
vocab, merges = train(data, 512, verbose=True)

merge 1/256: (101, 32) -> 256 (b'e ') had 2981 occurrences
merge 2/256: (44, 32) -> 257 (b', ') had 2961 occurrences
merge 3/256: (100, 32) -> 258 (b'd ') had 2617 occurrences
merge 4/256: (46, 32) -> 259 (b'. ') had 2560 occurrences
merge 5/256: (114, 32) -> 260 (b'r ') had 2428 occurrences
merge 6/256: (50, 48) -> 261 (b'20') had 2365 occurrences
merge 7/256: (115, 32) -> 262 (b's ') had 2053 occurrences
merge 8/256: (105, 110) -> 263 (b'in') had 2006 occurrences
merge 9/256: (111, 110) -> 264 (b'on') had 1815 occurrences
merge 10/256: (114, 105) -> 265 (b'ri') had 1805 occurrences
merge 11/256: (116, 32) -> 266 (b't ') had 1802 occurrences
merge 12/256: (116, 104) -> 267 (b'th') had 1737 occurrences
merge 13/256: (101, 258) -> 268 (b'ed ') had 1736 occurrences
merge 14/256: (257, 261) -> 269 (b', 20') had 1705 occurrences
merge 15/256: (97, 110) -> 270 (b'an') had 1487 occurrences
merge 16/256: (97, 114) -> 271 (b'ar') had 1360 occurrences
merge 17/256: (101, 260) -> 272 (b'er ') ha

In [35]:
def save(file_prefix, pattern, special_tokens, merges, vocab):
    """
    Saves two files: file_prefix.vocab and file_prefix.model
    This is inspired (but not equivalent to!) sentencepiece's model saving:
    - model file is the critical one, intended for load()
    - vocab file is just a pretty printed version for human inspection only

    Args:
        file_prefix: path prefix; ".model" and ".vocab" are appended to it.
        pattern: pattern string, written on its own line of the model file.
        special_tokens: dict of special-token string -> token id.
        merges: dict of (int, int) pair -> merged token id.
        vocab: dict of token id -> bytes, rendered into the .vocab file.
    """
    # write the model: to be used in load() later
    model_file = file_prefix + ".model"
    # Write UTF-8 explicitly: load() reads this file as UTF-8, and the pattern
    # or special tokens may contain non-ASCII text.
    with open(model_file, "w", encoding="utf-8") as f:
        # write the version, pattern and merges, that's all that's needed
        f.write("minbpe v1\n")
        f.write(f"{pattern}\n")
        # write the special tokens, first the number of them, then each one
        # NOTE: load() splits these lines on whitespace, so special tokens
        # must not contain spaces.
        f.write(f"{len(special_tokens)}\n")
        for special, idx in special_tokens.items():
            f.write(f"{special} {idx}\n")
        # Only the pairs are stored; load() re-assigns ids 256, 257, ... in
        # file order, so write the pairs sorted by their merge id.
        for (idx1, idx2), _ in sorted(merges.items(), key=lambda item: item[1]):
            f.write(f"{idx1} {idx2}\n")

    # write the vocab: for the human to look at
    vocab_file = file_prefix + ".vocab"
    inverted_merges = {idx: pair for pair, idx in merges.items()}
    with open(vocab_file, "w", encoding="utf-8") as f:
        for idx, token in vocab.items():
            # note: many tokens may be partial utf-8 sequences
            # and cannot be decoded into valid strings. Here we're using
            # errors='replace' to replace them with the replacement char �.
            # this also means that we couldn't possibly use .vocab in load()
            # because decoding in this way is a lossy operation!
            s = render_token(token)
            # find the children of this token, if any
            if idx in inverted_merges:
                # if this token has children, render it nicely as a merge
                idx0, idx1 = inverted_merges[idx]
                s0 = render_token(vocab[idx0])
                s1 = render_token(vocab[idx1])
                f.write(f"[{s0}][{s1}] -> [{s}] {idx}\n")
            else:
                # otherwise this is leaf token, just print it
                # (this should just be the first 256 tokens, the bytes)
                f.write(f"[{s}] {idx}\n")

In [36]:
# Output directory for saved tokenizers; "taylor1" is the file prefix that
# save() extends with ".model" and ".vocab".
os.makedirs("tokenizers", exist_ok=True)
path = "tokenizers/taylor1"

In [37]:
# Writes tokenizers/taylor1.model (read by load()) and tokenizers/taylor1.vocab (human-readable).
save(path, pattern, special_tokens, merges, vocab)

In [45]:
def load(model_file):
    """Inverse of save() but only for the model file.

    Args:
        model_file: path ending in ".model", as written by save().

    Returns:
        (vocab, merges, special_tokens, pattern): vocab maps id -> bytes;
        merges maps (id, id) -> id, with ids assigned 256, 257, ... in file
        order; special_tokens maps str -> id; pattern is the stored string.

    Raises:
        AssertionError: if the path does not end in ".model" or the version
            header is not "minbpe v1".
    """
    assert model_file.endswith(".model")
    merges = {}
    special_tokens = {}
    idx = 256
    with open(model_file, "r", encoding="utf-8") as f:
        # read the version
        version = f.readline().strip()
        assert version == "minbpe v1"
        # read the pattern
        pattern = f.readline().strip()
        # read the special tokens
        num_special = int(f.readline().strip())
        for _ in range(num_special):
            special, special_idx = f.readline().strip().split()
            special_tokens[special] = int(special_idx)
        # read the merges; each line's pair gets the next sequential id
        for line in f:
            if not line.strip():
                continue  # tolerate blank/trailing lines
            idx1, idx2 = map(int, line.split())
            merges[(idx1, idx2)] = idx
            idx += 1

    # Build the vocab from the tables just read. Calling _build_vocab() here
    # would read the notebook's GLOBAL merges/special_tokens instead, and so
    # would yield a 256-byte vocab on a fresh kernel.
    vocab = {i: bytes([i]) for i in range(256)}
    for (p0, p1), merged_id in merges.items():  # ids increase in file order
        vocab[merged_id] = vocab[p0] + vocab[p1]
    for special, special_id in special_tokens.items():
        vocab[special_id] = special.encode("utf-8")

    return vocab, merges, special_tokens, pattern

In [47]:
# Reload the tokenizer from disk. This rebinds the global vocab and merges
# (plus special_tokens and pattern) that decode()/encode() read below.
model_file = "tokenizers/taylor1.model"
vocab, merges, special_tokens, pattern = load(model_file)

In [51]:
def decode(ids):
    """Convert a list of token ids back into a str using the global ``vocab``.

    Byte sequences that are not valid UTF-8 are replaced with U+FFFD.
    """
    pieces = [vocab[token_id] for token_id in ids]
    return b"".join(pieces).decode("utf-8", errors="replace")


def encode(text):
    """Encode ``text`` into token ids using the global ``merges`` table.

    Starts from the raw UTF-8 bytes and repeatedly applies the applicable
    merge that was learned earliest (lowest merge id) until no adjacent pair
    has a merge.
    """
    ids = list(text.encode("utf-8"))
    never = float("inf")
    while len(ids) > 1:
        stats = get_stats(ids)
        # Rank adjacent pairs by merge id; pairs without a merge rank as inf.
        best = min(stats, key=lambda p: merges.get(p, never))
        # If even the best-ranked pair has no merge, min() returned an
        # arbitrary (first) pair: there is nothing left to merge.
        if best not in merges:
            break
        ids = merge(ids, best, merges[best])
    return ids

In [53]:
# Short sample sentence for an encode/decode round-trip check.
sentence = "I love you Puchu"

In [54]:
# Token ids for the sample sentence; ids >= 256 are learned merges.
encode(sentence)

[73, 32, 108, 346, 256, 121, 321, 32, 80, 117, 284, 117]

In [55]:
# Round-trip check: decoding the encoded ids must reproduce the input exactly.
roundtrip = decode(encode(sentence))
assert roundtrip == sentence, "BPE encode/decode round-trip failed"
roundtrip

'I love you Puchu'