In [419]:
import pandas as pd
from collections import defaultdict
import regex as re

## 2.5 Experimenting with BPE Tokenizer Training


In [None]:

def load_txt_as_str(input_path: str) -> str:
    with open(input_path, "r", encoding="utf-8") as f:
        text = f.read()
    return text

def split_string(string: str, special_tokens: list[str]) -> list[str]:
    pattern = "|".join(re.escape(tok) for tok in special_tokens)
    return re.split(pattern,string)

PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
def get_tok_counts(string_list: list[str]) -> dict[str, int]:
    counts = defaultdict(int)
    for s in string_list:
        tokens = re.finditer(PAT, s)
        for m in tokens:
            tok = m.group(0)
            counts[tok] += 1
    return counts

def get_byte_counts(counts: dict[str, int])-> dict[str, int]:
    element_counts = defaultdict(int)
    for token, count in counts.items():
        elements = tuple(token.encode("utf-8"))
        element_counts[elements] += count
    return element_counts

def get_pair_counts(element_counts: dict[str, int]) -> dict[tuple[int,int], int]:
    pair_counts = defaultdict(int)
    for elements, count in element_counts.items():
        for i in range(len(elements)-1):
            pair_counts[(elements[i],elements[i+1])] += count
    return pair_counts


def update_element_counts(byte_level_counts: dict[str, int], pair: tuple[int, int], new_index: int) -> dict[str, int]:
    new_byte_level_counts = {}
    for elements, counts in byte_level_counts.items():
        new_element = []
        elements_len = len(elements)
        index = 0
        while index <= elements_len-1:
            if (index < elements_len-1) and (elements[index] == pair[0]) and (elements[index+1] == pair[1]):
                new_element.append(new_index)
                index += 2
            else:
                new_element.append(elements[index])
                index += 1
        new_byte_level_counts[tuple(new_element)] = counts
    return new_byte_level_counts  

def initiate_vocab(special_tokens: list[str]) ->  dict[int, bytes]:
    vocab = {i:bytes([i]) for i in range(256)}
    for i, tok in enumerate(special_tokens, start=256):
        vocab[i] = tok.encode("utf-8")
    return vocab   

def find_max_pair(pair_counts: dict[tuple[int,int], int], vocab:dict[int, bytes]) -> tuple[int, int]:
    max_count = max(pair_counts.values())
    candidate_pairs = [key for key, value in pair_counts.items() if value == max_count]
    def sort_pair(pair):
        index1, index2 = pair
        return(vocab[index1], vocab[index2])
    pair = max(candidate_pairs, key = sort_pair)
    return pair


In [None]:
def pre_tokenize(string: str,vocab_size: int,special_tokens: list[str]) -> tuple[dict[int, bytes],list[tuple[bytes, bytes]]]:
    merges = []
    string_list = split_string(string, special_tokens)
    word_level_counts = get_tok_counts(string_list)
    byte_level_counts = get_byte_counts(word_level_counts)
    vocab = initiate_vocab(special_tokens)
    vocab_len = len(vocab)

    while vocab_len<vocab_size:
        pair_counts = get_pair_counts(byte_level_counts)
        if len(pair_counts) == 0:
            break
        pair = find_max_pair(pair_counts, vocab)
        index1, index2 = pair
        new_token = vocab[int(index1)]+vocab[int(index2)]
        new_index = vocab_len
        byte_level_counts = update_element_counts(byte_level_counts, pair,new_index)
        merges.append((vocab[int(index1)], vocab[int(index2)]))
        vocab[new_index] = new_token
        vocab_len+=1
    return vocab, merges

def train_bpe(input_path: str, vocab_size: int, special_tokens: list[str]) -> tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:
    string = load_txt_as_str(input_path)
    vocab, merges = pre_tokenize(string, vocab_size,special_tokens)
    return vocab, merges

In [None]:
special_tokens = ['<|endoftext|>']
input_path = fr'./data/test.txt'

epochs = 6
vocab_size = 270
string = "hi. i'm yifan li. nice to meet you.<|endoftext|> this the what when here where"
input_path = r'./data/test.txt'
string = load_txt_as_str(input_path)

In [360]:
vocab, merges = pre_tokenize(string,vocab_size,special_tokens)

In [364]:
vocab, merges = train_bpe(input_path,vocab_size,special_tokens)

## Self Test

In [None]:
merges = []
string_list = split_string(string, special_tokens)
word_level_counts = get_tok_counts(string_list)
byte_level_counts = get_byte_counts(word_level_counts)
vocab = initiate_vocab(special_tokens)
vocab_len = len(vocab)

pair_counts = get_pair_counts(byte_level_counts)

In [425]:
pair_counts = {
    (100,200):2,
    (101,201):2,
    (102,202):1
}

In [426]:
pair = max(pair_counts, key=lambda k: (pair_counts[k], k))
pair

(101, 201)

## Unit test

In [402]:
input_path = r"./tests/fixtures/tinystories_sample_5M.txt"

In [403]:
vocab, merges = train_bpe(
        input_path=input_path,
        vocab_size=1000,
        special_tokens=["<|endoftext|>"],
    )

In [404]:
vocabs_without_specials = [word for word in vocab.values() if word != b"<|endoftext|>"]
for word_bytes in vocabs_without_specials:
    assert b"<|" not in word_bytes

In [405]:
vocab_size = 1000
special_tokens=["<|endoftext|>"]
vocab, merges = train_bpe(input_path, vocab_size, special_tokens)

In [406]:
import pickle

with open("./tests/_snapshots/test_train_bpe_special_tokens.pkl", "rb") as f:
    data = pickle.load(f)
    
data.keys()

dict_keys(['vocab_keys', 'vocab_values', 'merges'])

In [407]:
for i, (my_merge, target_merge) in enumerate(zip(merges,data['merges'])):
    if not my_merge == target_merge:
        print([i, my_merge, target_merge])
        break

[34, (b' s', b'a'), (b'i', b'm')]


In [428]:
merges

[]

In [427]:
data['merges']

[(b'h', b'e'),
 (b' ', b't'),
 (b' ', b'a'),
 (b' ', b's'),
 (b' ', b'w'),
 (b'n', b'd'),
 (b' t', b'he'),
 (b'e', b'd'),
 (b' ', b'b'),
 (b' t', b'o'),
 (b' a', b'nd'),
 (b' ', b'h'),
 (b' ', b'f'),
 (b'i', b'n'),
 (b' w', b'a'),
 (b' ', b'T'),
 (b'i', b't'),
 (b'r', b'e'),
 (b'o', b'u'),
 (b' ', b'l'),
 (b' ', b'd'),
 (b' ', b'c'),
 (b' ', b'p'),
 (b'a', b'y'),
 (b' wa', b's'),
 (b'e', b'r'),
 (b' ', b'm'),
 (b'o', b'm'),
 (b' ', b'he'),
 (b' T', b'he'),
 (b'i', b's'),
 (b' ', b'n'),
 (b'o', b'n'),
 (b'a', b'r'),
 (b'i', b'm'),
 (b' s', b'a'),
 (b'l', b'l'),
 (b'i', b'd'),
 (b' h', b'a'),
 (b' ', b'g'),
 (b' ', b'S'),
 (b'a', b't'),
 (b'in', b'g'),
 (b'o', b't'),
 (b'e', b'n'),
 (b'a', b'n'),
 (b'l', b'e'),
 (b'o', b'r'),
 (b'i', b'r'),
 (b' ', b'H'),
 (b'a', b'm'),
 (b'e', b't'),
 (b' ', b'it'),
 (b' t', b'h'),
 (b'i', b'g'),
 (b' The', b'y'),
 (b'i', b'l'),
 (b' ', b'in'),
 (b' H', b'e'),
 (b' p', b'l'),
 (b' ', b'"'),
 (b'o', b'w'),
 (b'v', b'er'),
 (b'r', b'i'),
 (b' ', b'u'),
 (

In [413]:
[list(b'i'),list(b' s')]

[[105], [32, 115]]

In [None]:
A=merges
B=data['merges']
only_in_A = list(set(B) - set(A))
only_in_A

[]

In [429]:
input_path = r"./tests/fixtures/tinystories_sample_5M.txt"
vocab_size = 1000
special_tokens=["<|endoftext|>"]

In [430]:
string = load_txt_as_str(input_path)

In [434]:
merges = []
string_list = split_string(string, special_tokens)
word_level_counts = get_tok_counts(string_list)
byte_level_counts = get_byte_counts(word_level_counts)
vocab = initiate_vocab(special_tokens)
vocab_len = len(vocab)

for i in range(34):
    pair_counts = get_pair_counts(byte_level_counts)
    pair = max(pair_counts, key=lambda k: (pair_counts[k], k))
    index1, index2 = pair
    new_token = vocab[int(index1)]+vocab[int(index2)]
    new_index = vocab_len
    byte_level_counts = update_element_counts(byte_level_counts, pair,new_index)
    merges.append((vocab[int(index1)], vocab[int(index2)]))
    vocab[new_index] = new_token
    vocab_len+=1

In [454]:
find_max_pair(pair_counts,vocab)

(105, 109)

In [438]:
pair_counts = get_pair_counts(byte_level_counts)
pair = max(pair_counts, key=lambda k: (pair_counts[k], k))
pair

(260, 97)

In [449]:
vocab[260]

b' s'

(105, 109)

In [450]:
candidate_pairs

[(260, 97), (105, 109)]

In [442]:
[list(b'i'),list(b'm')]

[[105], [109]]

In [443]:
pair_counts[(105,109)]

21022

In [None]:
pair_counts[(260,97)]

21022

In [445]:
pair_counts

defaultdict(int,
            {(277, 289): 849,
             (39, 116): 2657,
             (268, 97): 19290,
             (97, 118): 5863,
             (118, 101): 9391,
             (265, 101): 10932,
             (260, 99): 2389,
             (99, 290): 1741,
             (290, 264): 2006,
             (32, 111): 13295,
             (111, 102): 7122,
             (276, 275): 548,
             (275, 100): 1445,
             (277, 111): 8068,
             (111, 103): 4272,
             (32, 73): 9003,
             (39, 108): 105,
             (108, 108): 20883,
             (279, 114): 1641,
             (114, 111): 6797,
             (111, 116): 16913,
             (116, 101): 3943,
             (101, 99): 5646,
             (99, 116): 2077,
             (32, 121): 9875,
             (121, 275): 8113,
             (34, 46): 248,
             (283, 111): 2753,
             (111, 108): 5493,
             (108, 101): 15975,
             (269, 101): 3936,
             (101, 108): 8755,
   

{0: b'<|endoftext|>',
 1: b'\x00',
 2: b'\x01',
 3: b'\x02',
 4: b'\x03',
 5: b'\x04',
 6: b'\x05',
 7: b'\x06',
 8: b'\x07',
 9: b'\x08',
 10: b'\t',
 11: b'\n',
 12: b'\x0b',
 13: b'\x0c',
 14: b'\r',
 15: b'\x0e',
 16: b'\x0f',
 17: b'\x10',
 18: b'\x11',
 19: b'\x12',
 20: b'\x13',
 21: b'\x14',
 22: b'\x15',
 23: b'\x16',
 24: b'\x17',
 25: b'\x18',
 26: b'\x19',
 27: b'\x1a',
 28: b'\x1b',
 29: b'\x1c',
 30: b'\x1d',
 31: b'\x1e',
 32: b'\x1f',
 33: b' ',
 34: b'!',
 35: b'"',
 36: b'#',
 37: b'$',
 38: b'%',
 39: b'&',
 40: b"'",
 41: b'(',
 42: b')',
 43: b'*',
 44: b'+',
 45: b',',
 46: b'-',
 47: b'.',
 48: b'/',
 49: b'0',
 50: b'1',
 51: b'2',
 52: b'3',
 53: b'4',
 54: b'5',
 55: b'6',
 56: b'7',
 57: b'8',
 58: b'9',
 59: b':',
 60: b';',
 61: b'<',
 62: b'=',
 63: b'>',
 64: b'?',
 65: b'@',
 66: b'A',
 67: b'B',
 68: b'C',
 69: b'D',
 70: b'E',
 71: b'F',
 72: b'G',
 73: b'H',
 74: b'I',
 75: b'J',
 76: b'K',
 77: b'L',
 78: b'M',
 79: b'N',
 80: b'O',
 81: b'P',
 82: b

In [218]:
i=2
vocab[i+special_tokens_len] = bytes([i])
vocab

{0: b'<|endoftext|>', 2: b'\x00', 3: b'\x02'}

In [None]:
vocab: dict[int, bytes]  = {i:bytes([i]) for i in range(256)}

In [185]:
import pathlib
FIXTURES_PATH = (pathlib.Path(__file__).resolve().parent) / "fixtures"

NameError: name '__file__' is not defined