In [7]:
import pandas as pd
from collections import defaultdict
import regex as re

## 2.5 Experimenting with BPE Tokenizer Training


In [8]:
def train_bpe(input_path: str, vocab_size: int, special_tokens: list[str]) -> tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:
    pass

In [None]:
def load_txt_as_str(input_path: str) -> str:
    with open(input_path, "r", encoding="utf-8") as f:
        text = f.read()
    return text

def split_string(string: str, special_tokens: list[str]) -> list[str]:
    pattern = "|".join(re.escape(tok) for tok in special_tokens)
    return re.split(pattern,string)

def get_tok_counts(string_list: list[str]) -> dict[str, int]:
    counts = defaultdict(int)
    for s in string_list:
        tokens = re.finditer(PAT, s)
        for m in tokens:
            tok = m.group(0)
            counts[tok] += 1
    return counts

def get_element_counts(counts: dict[str, int])-> dict[str, int]:
    element_counts = defaultdict(int)
    for token, count in counts.items():
        elements = tuple([k for k in token])
        element_counts[elements] += count
    return element_counts

def get_pair_counts(element_counts: dict[str, int]) -> dict[tuple[str,str], int]:
    pair_counts = defaultdict(int)
    for elements, count in element_counts.items():
        for i in range(len(elements)-1):
            pair_counts[(elements[i],elements[i+1])] += count
    return pair_counts


def update_element_counts(element_counts: dict[str, int], pair: tuple[str, str]) -> dict[str, int]:
    new_element_counts = {}
    for elements, counts in element_counts.items():
        new_element = []
        elements_len = len(elements)
        index = 0
        while index <= elements_len-1:
            if (index < elements_len-1) and (elements[index] == pair[0]) and (elements[index+1] == pair[1]):
                new_element.append("".join(elements[index:index+2]))
                index += 2
            else:
                new_element.append(elements[index])
                index += 1
        new_element_counts[tuple(new_element)] = counts
    return new_element_counts  

def initiate_vocab(special_tokens: list[str], unique_elements: set[str]) ->  dict[int, bytes]:
    special_tokens_len = len(special_tokens)
    vocab = {i: tok.encode("utf-8") for i, tok in enumerate(special_tokens)}
    for i, elem in enumerate(unique_elements, start=special_tokens_len):
        vocab[i] = elem.encode("utf-8")
    return vocab     

In [None]:
def pre_tokenize(string: str,vocab_size: int,special_tokens: list[str]) -> tuple[dict[int, bytes],list[tuple[bytes, bytes]]]:
    merges = []
    string_list = split_string(string, special_tokens)
    counts = get_tok_counts(string_list)
    element_counts = get_element_counts(counts)
    unique_elements = set().union(*element_counts.keys())
    vocab = initiate_vocab(special_tokens,unique_elements)
    vocab_init_len = len(vocab)
    while vocab_init_len<vocab_size:
        pair_counts = get_pair_counts(element_counts)
        pair = max(pair_counts, key=lambda k: (pair_counts[k], k))
        merges.append(pair)
        element_counts = update_element_counts(element_counts, pair)
        vocab[vocab_init_len] = "".join(pair).encode("utf-8")
        vocab_init_len = len(vocab)
    return vocab, merges


SyntaxError: expected ':' (2681849556.py, line 1)

In [157]:
special_tokens = ['<|endoftext|>']
input_path = fr'./data/test.txt'
PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
epochs = 6
vocab_size = 30
string = "hi. i'm yifan li. nice to meet you. this the what when here where"

In [160]:
vocab

{0: b'<|endoftext|>',
 1: b's',
 2: b'c',
 3: b't',
 4: b'f',
 5: b'y',
 6: b'a',
 7: b'o',
 8: b'h',
 9: b'n',
 10: b'e',
 11: b'l',
 12: b'w',
 13: b'r',
 14: b'm',
 15: b'i',
 16: b'.',
 17: b"'",
 18: b'u',
 19: b' ',
 20: b'he',
 21: b' w',
 22: b' t',
 23: b're',
 24: b'here',
 25: b'hi',
 26: b' y',
 27: b'ou',
 28: b'ni',
 29: b'nic'}

In [156]:
merge

[('h', 'e'), (' ', 'w'), (' ', 't'), ('r', 'e'), ('he', 're'), ('h', 'i')]

In [112]:
test_string = "hello! こんにちは!"
utf8_encoded = test_string.encode("utf-8")
utf8_encoded

b'hello! \xe3\x81\x93\xe3\x82\x93\xe3\x81\xab\xe3\x81\xa1\xe3\x81\xaf!'

In [53]:
elements

(' ', 'w', 'h', 'e', 'r', 'e')

In [93]:
string

'u don\'t have to be scared of the loud dog, I\'ll protect you". The mole felt so safe with the little girl. She was very kind and the mole soon came to trust her. He leaned against her and she kept him safe. The mole had found his best friend.\n<|endoftext|>\nOnce upon a time, in a warm and sunny place, there was a big pit. A little boy named Tom liked to play near the pit. One day, Tom lost his red ball. He was very sad.\nTom asked his friend, Sam, to help him search for the ball. They looked high and low, but they could not find the ball. Tom said, "I think my ball fell into the pit."\nSam and Tom went close to the pit. They were scared, but they wanted to find the red ball. They looked into the pit, but it was too dark to see. Tom said, "We must go in and search for my ball."\nThey went into the pit to search. It was dark and scary. They could not find the ball. They tried to get out, but the pit was too deep. Tom and Sam were stuck in the pit. They called for help, but no one coul

In [55]:
element_counts

defaultdict(int,
            {('h', 'i'): 1,
             ('.',): 3,
             (' ', 'i'): 1,
             ("'", 'm'): 1,
             (' ', 'y', 'i', 'f', 'a', 'n'): 1,
             (' ', 'l', 'i'): 1,
             (' ', 'n', 'i', 'c', 'e'): 1,
             (' ', 't', 'o'): 1,
             (' ', 'm', 'e', 'e', 't'): 1,
             (' ', 'y', 'o', 'u'): 1,
             (' ', 't', 'h', 'i', 's'): 1,
             (' ', 't', 'h', 'e'): 1,
             (' ', 'w', 'h', 'a', 't'): 1,
             (' ', 'w', 'h', 'e', 'n'): 1,
             (' ', 'h', 'e', 'r', 'e'): 1,
             (' ', 'w', 'h', 'e', 'r', 'e'): 1})

In [148]:
pair_counts

defaultdict(int,
            {'hi': 2,
             ' i': 1,
             "'m": 1,
             ' y': 2,
             'yi': 1,
             'if': 1,
             'fa': 1,
             'an': 1,
             ' l': 1,
             'li': 1,
             ' n': 1,
             'ni': 1,
             'ic': 1,
             'ce': 1,
             ' t': 3,
             'to': 1,
             ' m': 1,
             'me': 1,
             'ee': 1,
             'et': 1,
             'yo': 1,
             'ou': 1,
             'th': 2,
             'is': 1,
             'he': 4,
             ' w': 3,
             'wh': 3,
             'ha': 1,
             'at': 1,
             'en': 1,
             ' h': 1,
             'er': 2,
             're': 2})