In [186]:
import pandas as pd
from collections import defaultdict
import regex as re

## 2.5 Experimenting with BPE Tokenizer Training


In [339]:
def load_txt_as_str(input_path: str) -> str:
    with open(input_path, "r", encoding="utf-8") as f:
        text = f.read()
    return text

def split_string(string: str, special_tokens: list[str]) -> list[str]:
    pattern = "|".join(re.escape(tok) for tok in special_tokens)
    return re.split(pattern,string)

PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
def get_tok_counts(string_list: list[str]) -> dict[str, int]:
    counts = defaultdict(int)
    for s in string_list:
        tokens = re.finditer(PAT, s)
        for m in tokens:
            tok = m.group(0)
            counts[tok] += 1
    return counts

def get_byte_counts(counts: dict[str, int])-> dict[str, int]:
    element_counts = defaultdict(int)
    for token, count in counts.items():
        elements = tuple(token.encode("utf-8"))
        element_counts[elements] += count
    return element_counts

def get_pair_counts(element_counts: dict[str, int]) -> dict[tuple[str,str], int]:
    pair_counts = defaultdict(int)
    for elements, count in element_counts.items():
        for i in range(len(elements)-1):
            pair_counts[(elements[i],elements[i+1])] += count
    return pair_counts


def update_element_counts(byte_level_counts: dict[str, int], pair: tuple[str, str], new_index: int) -> dict[str, int]:
    new_byte_level_counts = {}
    for elements, counts in byte_level_counts.items():
        new_element = []
        elements_len = len(elements)
        index = 0
        while index <= elements_len-1:
            if (index < elements_len-1) and (elements[index] == pair[0]) and (elements[index+1] == pair[1]):
                new_element.append(new_index)
                index += 2
            else:
                new_element.append(elements[index])
                index += 1
        new_byte_level_counts[tuple(new_element)] = counts
    return new_byte_level_counts  

def initiate_vocab(special_tokens: list[str]) ->  dict[int, bytes]:
    vocab = {i:bytes([i]) for i in range(256)}
    for i, tok in enumerate(special_tokens, start=256):
        vocab[i] = tok.encode("utf-8")
    return vocab   

In [None]:
def pre_tokenize(string: str,vocab_size: int,special_tokens: list[str]) -> tuple[dict[int, bytes],list[tuple[bytes, bytes]]]:
    merges = []
    string_list = split_string(string, special_tokens)
    word_level_counts = get_tok_counts(string_list)
    byte_level_counts = get_byte_counts(word_level_counts)
    vocab = initiate_vocab(special_tokens)
    vocab_len = len(vocab)

    while vocab_len<vocab_size:
        pair_counts = get_pair_counts(byte_level_counts)
        while len(pair_counts) == 0:
            break
        pair = max(pair_counts, key=lambda k: (pair_counts[k], k))
        index1, index2 = pair
        new_token = vocab[int(index1)]+vocab[int(index2)]
        new_index = vocab_len
        byte_level_counts = update_element_counts(byte_level_counts, pair,new_index)
        merges.append((vocab[int(index1)], vocab[int(index2)]))
        vocab[new_index] = new_token
        vocab_len+=1
    return vocab, merges

def train_bpe(input_path: str, vocab_size: int, special_tokens: list[str]) -> tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:
    string = load_txt_as_str(input_path)
    vocab, merges = pre_tokenize(string, vocab_size,special_tokens)
    return vocab, merges

In [346]:
special_tokens = ['<|endoftext|>']
input_path = fr'./data/test.txt'

epochs = 6
vocab_size = 300
string = "hi. i'm yifan li. nice to meet you.<|endoftext|> this the what when here where"
# input_path = r'./data/test.txt'
# string = load_txt_as_str(input_path)

In [None]:
pre_tokenize(string,vocab_size,special_tokens: list[str]) -> tuple[dict[int, bytes],list[tuple[bytes, bytes]]]:

In [347]:
merges = []
string_list = split_string(string, special_tokens)
word_level_counts = get_tok_counts(string_list)
byte_level_counts = get_byte_counts(word_level_counts)
vocab = initiate_vocab(special_tokens)
special_tokens_len = len(special_tokens)
vocab_len = len(vocab)

while vocab_len<vocab_size:
    pair_counts = get_pair_counts(byte_level_counts)
    pair = max(pair_counts, key=lambda k: (pair_counts[k], k))
    index1, index2 = pair
    new_token = vocab[int(index1)]+vocab[int(index2)]
    new_index = vocab_len
    byte_level_counts = update_element_counts(byte_level_counts, pair,new_index)
    merges.append((vocab[int(index1)], vocab[int(index2)]))
    vocab[new_index] = new_token
    vocab_len+=1

ValueError: max() iterable argument is empty

In [348]:
pair_counts

defaultdict(int, {})

In [344]:
byte_level_counts

{(104, 105): 1,
 (46,): 3,
 (32, 105): 1,
 (39, 109): 1,
 (32, 121, 105, 102, 97, 110): 1,
 (32, 108, 105): 1,
 (32, 110, 105, 99, 101): 1,
 (32, 116, 111): 1,
 (32, 109, 101, 101, 116): 1,
 (32, 121, 111, 117): 1,
 (32, 116, 104, 105, 115): 1,
 (32, 116, 257): 1,
 (32, 119, 104, 97, 116): 1,
 (32, 119, 257, 110): 1,
 (32, 257, 114, 101): 1,
 (32, 119, 257, 114, 101): 1}

In [345]:
new_token

b'he'

In [310]:
def get_byte_counts(counts: dict[str, int])-> dict[str, int]:
    element_counts = defaultdict(int)
    for token, count in counts.items():
        elements = tuple(token.encode("utf-8"))
        element_counts[elements] += count
    return element_counts

In [296]:
vocab, merges = train_bpe(input_path, vocab_size, special_tokens)

In [297]:
vocab

{0: b'<|endoftext|>',
 1: b'\x00',
 2: b'\x01',
 3: b'\x02',
 4: b'\x03',
 5: b'\x04',
 6: b'\x05',
 7: b'\x06',
 8: b'\x07',
 9: b'\x08',
 10: b'\t',
 11: b'\n',
 12: b'\x0b',
 13: b'\x0c',
 14: b'\r',
 15: b'\x0e',
 16: b'\x0f',
 17: b'\x10',
 18: b'\x11',
 19: b'\x12',
 20: b'\x13',
 21: b'\x14',
 22: b'\x15',
 23: b'\x16',
 24: b'\x17',
 25: b'\x18',
 26: b'\x19',
 27: b'\x1a',
 28: b'\x1b',
 29: b'\x1c',
 30: b'\x1d',
 31: b'\x1e',
 32: b'\x1f',
 33: b' ',
 34: b'!',
 35: b'"',
 36: b'#',
 37: b'$',
 38: b'%',
 39: b'&',
 40: b"'",
 41: b'(',
 42: b')',
 43: b'*',
 44: b'+',
 45: b',',
 46: b'-',
 47: b'.',
 48: b'/',
 49: b'0',
 50: b'1',
 51: b'2',
 52: b'3',
 53: b'4',
 54: b'5',
 55: b'6',
 56: b'7',
 57: b'8',
 58: b'9',
 59: b':',
 60: b';',
 61: b'<',
 62: b'=',
 63: b'>',
 64: b'?',
 65: b'@',
 66: b'A',
 67: b'B',
 68: b'C',
 69: b'D',
 70: b'E',
 71: b'F',
 72: b'G',
 73: b'H',
 74: b'I',
 75: b'J',
 76: b'K',
 77: b'L',
 78: b'M',
 79: b'N',
 80: b'O',
 81: b'P',
 82: b

In [298]:
merges

[(b'g', b'd'),
 (b'\x1f', b's'),
 (b'\x1fs', b'gd'),
 (b'\x1f', b'r'),
 (b'm', b'c'),
 (b'\x1f', b'a'),
 (b'\x1f', b'`'),
 (b'\x1f', b'S'),
 (b'd', b'c'),
 (b'\x1f', b'v'),
 (b'h', b's'),
 (b'`', b'q'),
 (b'\x1f', b'o'),
 (b'\x1fs', b'n'),
 (b'\x1f', b'e'),
 (b'\x1f`', b'mc'),
 (b'k', b'k'),
 (b'\x1fo', b'hs'),
 (b'\x1fS', b'gd'),
 (b'\x1f', b'k'),
 (b'n', b't'),
 (b'h', b'm'),
 (b'`', b'kk'),
 (b'\x1f', b'g'),
 (b'\x1fSgd', b'x'),
 (b'\x1fv', b'`'),
 (b'\x1fa', b'`kk'),
 (b'n', b'l'),
 (b'\x1fv`', b'r'),
 (b'\x1fg', b'h'),
 (b'\x1fS', b'nl'),
 (b'\x1fr', b'`'),
 (b'd', b'q'),
 (b'\x1f', b'hm'),
 (b'\x1f', b'm'),
 (b'\x1f', b'l'),
 (b'\x1f', b'c'),
 (b'\x1f', b'b'),
 (b'\x1fa', b't'),
 (b'\x1fat', b's'),
 (b'r', b's'),
 (b'k', b'd'),
 (b'`', b'l')]

In [264]:
byte_level_counts = byte_level_counts
pair = pair
new_index = new_index

In [265]:
new_byte_level_counts = {}
for elements, counts in byte_level_counts.items():
    new_element = []
    elements_len = len(elements)
    index = 0
    while index <= elements_len-1:
        if (index < elements_len-1) and (elements[index] == pair[0]) and (elements[index+1] == pair[1]):
            new_element.append(new_index)
            index += 2
        else:
            new_element.append(elements[index])
            index += 1
    new_byte_level_counts[tuple(new_element)] = counts

{0: b'<|endoftext|>',
 1: b'\x00',
 2: b'\x01',
 3: b'\x02',
 4: b'\x03',
 5: b'\x04',
 6: b'\x05',
 7: b'\x06',
 8: b'\x07',
 9: b'\x08',
 10: b'\t',
 11: b'\n',
 12: b'\x0b',
 13: b'\x0c',
 14: b'\r',
 15: b'\x0e',
 16: b'\x0f',
 17: b'\x10',
 18: b'\x11',
 19: b'\x12',
 20: b'\x13',
 21: b'\x14',
 22: b'\x15',
 23: b'\x16',
 24: b'\x17',
 25: b'\x18',
 26: b'\x19',
 27: b'\x1a',
 28: b'\x1b',
 29: b'\x1c',
 30: b'\x1d',
 31: b'\x1e',
 32: b'\x1f',
 33: b' ',
 34: b'!',
 35: b'"',
 36: b'#',
 37: b'$',
 38: b'%',
 39: b'&',
 40: b"'",
 41: b'(',
 42: b')',
 43: b'*',
 44: b'+',
 45: b',',
 46: b'-',
 47: b'.',
 48: b'/',
 49: b'0',
 50: b'1',
 51: b'2',
 52: b'3',
 53: b'4',
 54: b'5',
 55: b'6',
 56: b'7',
 57: b'8',
 58: b'9',
 59: b':',
 60: b';',
 61: b'<',
 62: b'=',
 63: b'>',
 64: b'?',
 65: b'@',
 66: b'A',
 67: b'B',
 68: b'C',
 69: b'D',
 70: b'E',
 71: b'F',
 72: b'G',
 73: b'H',
 74: b'I',
 75: b'J',
 76: b'K',
 77: b'L',
 78: b'M',
 79: b'N',
 80: b'O',
 81: b'P',
 82: b

In [218]:
i=2
vocab[i+special_tokens_len] = bytes([i])
vocab

{0: b'<|endoftext|>', 2: b'\x00', 3: b'\x02'}

In [None]:
vocab: dict[int, bytes]  = {i:bytes([i]) for i in range(256)}

In [185]:
import pathlib
FIXTURES_PATH = (pathlib.Path(__file__).resolve().parent) / "fixtures"

NameError: name '__file__' is not defined