In [1]:
import pandas as pd
from collections import defaultdict
import regex as re
from multiprocessing import Pool
from support.find_chunk_boundaries import find_chunk_boundaries
from memory_profiler import profile

## Q1.1 Problem (train_bpe): BPE Tokenizer Training 


In [2]:

def load_txt_as_str(input_path: str) -> str:
    """Read an entire UTF-8 text file into memory.

    Args:
        input_path: Path of the text file to read.

    Returns:
        The full decoded contents of the file as a single string.
    """
    with open(input_path, mode="r", encoding="utf-8") as handle:
        return handle.read()

def split_string(string: str, special_tokens: list[str]) -> list[str]:
    """Split ``string`` at every occurrence of any special token.

    The special tokens themselves are dropped, so no later processing can
    combine text across a special-token boundary.

    Args:
        string: Text to split.
        special_tokens: Literal delimiter strings; they are regex-escaped, so
            characters such as ``|`` are matched verbatim.

    Returns:
        The pieces of ``string`` between special tokens (possibly empty
        strings). When there are no usable special tokens the result is
        ``[string]``.
    """
    # Drop empty tokens: an empty alternative matches at every position.
    delimiters = [tok for tok in (special_tokens or []) if tok]
    if not delimiters:
        # An empty pattern would split the text between every character.
        return [string]
    # Longest first, so a token that is a prefix of another cannot shadow it.
    delimiters.sort(key=len, reverse=True)
    pattern = "|".join(re.escape(tok) for tok in delimiters)
    return re.split(pattern, string)

# Pre-tokenization pattern. Alternatives, in order: common contraction suffixes
# ('s, 'd, 'm, 't, 'll, 've, 're); a run of letters; a run of digits; a run of
# other non-space characters (each optionally preceded by one space); trailing
# whitespace; any remaining whitespace. The \p{...} classes rely on the `regex`
# package that this notebook imports under the name `re`.
PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
def get_tok_counts(string_list: list[str]) -> dict[str, int]:
    """Count how often each pre-token occurs across a list of text segments.

    Args:
        string_list: Text segments; each one is scanned independently with PAT.

    Returns:
        A ``defaultdict(int)`` mapping every matched pre-token string to its
        total number of occurrences.
    """
    tally = defaultdict(int)
    for segment in string_list:
        for match in re.finditer(PAT, segment):
            tally[match.group(0)] += 1
    return tally

def get_byte_counts(counts: dict[str, int]) -> dict[tuple[int, ...], int]:
    """Re-key pre-token counts by their UTF-8 byte sequence.

    Args:
        counts: Mapping of pre-token string to occurrence count.

    Returns:
        A ``defaultdict(int)`` whose keys are tuples of byte values (0-255)
        obtained by UTF-8 encoding each pre-token, and whose values are the
        summed counts.
    """
    element_counts = defaultdict(int)
    for token, count in counts.items():
        # Iterating over a bytes object yields ints, so the key is a tuple of ints.
        element_counts[tuple(token.encode("utf-8"))] += count
    return element_counts

def get_pair_counts(element_counts: dict[tuple[int, ...], int]) -> dict[tuple[int, int], int]:
    """Count adjacent element pairs, weighted by how often each sequence occurs.

    Args:
        element_counts: Mapping of element-id sequence to occurrence count.

    Returns:
        A ``defaultdict(int)`` mapping each adjacent pair ``(left, right)`` to
        the sum of the counts of every sequence position where it appears.
        Sequences with fewer than two elements contribute nothing.
    """
    pair_counts = defaultdict(int)
    for elements, count in element_counts.items():
        # zip against the sequence shifted by one to walk adjacent pairs.
        for left, right in zip(elements, elements[1:]):
            pair_counts[(left, right)] += count
    return pair_counts


def update_element_counts(byte_level_counts: dict[tuple[int, ...], int], pair: tuple[int, int], new_index: int) -> dict[tuple[int, ...], int]:
    """Apply one merge: replace every occurrence of ``pair`` with ``new_index``.

    Each sequence is scanned left to right and matches are consumed greedily,
    so overlapping occurrences (e.g. ``(a, a, a)`` with pair ``(a, a)``) merge
    only the leftmost pair.

    Args:
        byte_level_counts: Mapping of element-id sequence to occurrence count.
        pair: The two adjacent element ids to merge.
        new_index: The id that replaces each merged pair.

    Returns:
        A new dict keyed by the rewritten sequences. If several input sequences
        rewrite to the same key, their counts are added together.
    """
    first, second = pair
    merged_counts: dict[tuple[int, ...], int] = {}
    for elements, count in byte_level_counts.items():
        rewritten = []
        length = len(elements)
        position = 0
        while position < length:
            if (
                position + 1 < length
                and elements[position] == first
                and elements[position + 1] == second
            ):
                rewritten.append(new_index)
                position += 2  # skip both halves of the merged pair
            else:
                rewritten.append(elements[position])
                position += 1
        key = tuple(rewritten)
        # Accumulate rather than assign so colliding keys never drop counts.
        merged_counts[key] = merged_counts.get(key, 0) + count
    return merged_counts

def initiate_vocab(special_tokens: list[str]) ->  dict[int, bytes]:
    """Build the starting vocabulary: all single bytes, then the special tokens.

    Args:
        special_tokens: Special-token strings; they receive consecutive ids
            starting at 256, in the order given.

    Returns:
        Mapping of token id to token bytes. Ids 0-255 map to the corresponding
        single byte; later ids map to the UTF-8 encoding of each special token.
    """
    vocab = {}
    for byte_value in range(256):
        vocab[byte_value] = bytes([byte_value])
    next_id = 256
    for tok in special_tokens:
        vocab[next_id] = tok.encode("utf-8")
        next_id += 1
    return vocab

def find_max_pair(pair_counts: dict[tuple[int,int], int], vocab:dict[int, bytes]) -> tuple[int, int]:
    """Pick the most frequent pair, breaking ties by the tokens' byte values.

    Args:
        pair_counts: Mapping of ``(left_id, right_id)`` to count; must be
            non-empty (``max`` raises ``ValueError`` otherwise).
        vocab: Mapping of token id to token bytes, used only for tie-breaking.

    Returns:
        The pair with the highest count. Among pairs sharing that count, the
        one whose ``(vocab[left], vocab[right])`` tuple compares greatest.
    """
    top_count = max(pair_counts.values())
    # Only the pairs tied at the top count take part in the byte-wise comparison.
    tied = (candidate for candidate, count in pair_counts.items() if count == top_count)
    return max(tied, key=lambda candidate: (vocab[candidate[0]], vocab[candidate[1]]))


def pre_tokenize(string: str,special_tokens: list[str]) -> dict[str, int]:
    """Split text on special tokens and count the pre-tokens in each piece.

    Args:
        string: Raw text to pre-tokenize.
        special_tokens: Delimiters removed by ``split_string`` before counting.

    Returns:
        The pre-token counts produced by ``get_tok_counts`` over all pieces.
    """
    return get_tok_counts(split_string(string, special_tokens))

def process_chunk(args):
    """Pre-tokenize one ``(chunk_text, special_tokens)`` work item.

    Takes a single packed argument so it can be handed to a one-argument
    mapping call such as ``Pool.map``.

    Args:
        args: A two-item iterable ``(chunk_text, special_tokens)``.

    Returns:
        The pre-token counts for the chunk, as returned by ``pre_tokenize``.
    """
    text, specials = args
    return pre_tokenize(text, specials)

def load_txt_in_chunks(input_path: str, special_tokens: list[str], num_processes: int = 4) -> list[tuple[str, list[str]]]:
    """Read a file as decoded text chunks packaged for ``process_chunk``.

    Args:
        input_path: Path of the file to read (opened in binary mode).
        special_tokens: Passed through unchanged alongside every chunk.
        num_processes: Desired number of chunks handed to
            ``find_chunk_boundaries``.

    Returns:
        A list of ``(chunk_text, special_tokens)`` tuples, one per pair of
        consecutive boundaries.
    """
    with open(input_path, "rb") as f:
        # NOTE(review): the delimiter is hard-coded rather than derived from
        # special_tokens; find_chunk_boundaries presumably returns ascending
        # byte offsets aligned to it — confirm against the helper.
        boundaries = find_chunk_boundaries(f, num_processes, b"<|endoftext|>")

        chunks = []
        for start, end in zip(boundaries[:-1], boundaries[1:]):
            f.seek(start)
            # errors="ignore" silently drops any bytes that are not valid UTF-8.
            chunk = f.read(end - start).decode("utf-8", errors="ignore")
            chunks.append((chunk, special_tokens))
    return chunks

def train_bpe(input_path: str, vocab_size: int, special_tokens: list[str]) -> tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:
    """Train a byte-level BPE vocabulary on a text file.

    The file is read whole, pre-tokenized, and converted to byte sequences.
    The pair chosen by ``find_max_pair`` is then merged repeatedly until the
    vocabulary reaches ``vocab_size`` or no adjacent pairs remain.

    Args:
        input_path: Path of the UTF-8 training text.
        vocab_size: Target number of vocabulary entries, including the 256
            byte tokens and the special tokens.
        special_tokens: Strings added to the vocabulary and used as split
            points during pre-tokenization.

    Returns:
        ``(vocab, merges)``: ``vocab`` maps token id to token bytes; ``merges``
        lists the merged ``(left_bytes, right_bytes)`` pairs in creation order.
    """
    # NOTE: pre-tokenization runs in a single process. load_txt_in_chunks and
    # process_chunk exist for a multiprocessing.Pool variant that would sum the
    # per-chunk counts, but that path is not wired in here.
    word_level_counts = pre_tokenize(load_txt_as_str(input_path), special_tokens)

    byte_level_counts = get_byte_counts(word_level_counts)
    vocab = initiate_vocab(special_tokens)
    merges = []
    while len(vocab) < vocab_size:
        # Pair statistics are rebuilt from scratch after every merge.
        pair_counts = get_pair_counts(byte_level_counts)
        if not pair_counts:
            break
        left_id, right_id = find_max_pair(pair_counts, vocab)
        left_bytes, right_bytes = vocab[left_id], vocab[right_id]
        # Ids are dense, so the next free id equals the current vocabulary size.
        new_index = len(vocab)
        byte_level_counts = update_element_counts(byte_level_counts, (left_id, right_id), new_index)
        merges.append((left_bytes, right_bytes))
        vocab[new_index] = left_bytes + right_bytes
    return vocab, merges

In [513]:
# Settings for the profiling run below.
special_tokens = ['<|endoftext|>']
input_path = r'./data/TinyStoriesV2-GPT4-valid.txt'
vocab_size = 270
# NOTE(review): `epochs` is not referenced by any cell visible in this notebook.
epochs = 6

# Whole corpus as one string. train_bpe is called with input_path and re-reads
# the file itself, so this copy is not what gets trained on.
string = load_txt_as_str(input_path)

In [None]:
%memit vocab, merges = train_bpe(input_path, vocab_size, special_tokens)

peak memory: 861.81 MiB, increment: 20.62 MiB


## Unit test

In [457]:
# Fixture used by the checks in this section (presumably a ~5 MB TinyStories
# sample, judging by the file name — confirm).
input_path = r"./tests/fixtures/tinystories_sample_5M.txt"

In [458]:
# Train on the fixture; `vocab` and `merges` are inspected by the cells below.
vocab, merges = train_bpe(
        input_path=input_path,
        vocab_size=1000,
        special_tokens=["<|endoftext|>"],
    )

In [459]:
# Sanity check: apart from the special token itself, no vocabulary entry may
# contain the b"<|" opener, i.e. no merge picked up part of a special token.
vocabs_without_specials = [word for word in vocab.values() if word != b"<|endoftext|>"]
for word_bytes in vocabs_without_specials:
    assert b"<|" not in word_bytes

In [460]:
# NOTE(review): this repeats the training call from the earlier cell with the
# same input_path, vocab size and special tokens; its main additional effect is
# to set the `vocab_size` and `special_tokens` globals. Consider dropping one run.
vocab_size = 1000
special_tokens=["<|endoftext|>"]
vocab, merges = train_bpe(input_path, vocab_size, special_tokens)

In [461]:
import pickle

# Load the reference snapshot to compare against. pickle.load can execute
# arbitrary code, so only use it on snapshot files from a trusted source.
with open("./tests/_snapshots/test_train_bpe_special_tokens.pkl", "rb") as f:
    data = pickle.load(f)
    
# Display the snapshot's top-level keys.
data.keys()

dict_keys(['vocab_keys', 'vocab_values', 'merges'])

In [462]:
# Report the first position where the learned merges diverge from the snapshot.
for i, (my_merge, target_merge) in enumerate(zip(merges, data['merges'])):
    if my_merge != target_merge:
        print([i, my_merge, target_merge])
        break
else:
    # zip() stops at the shorter sequence, so a missing or extra tail would
    # otherwise go unnoticed; report a length difference explicitly.
    if len(merges) != len(data['merges']):
        print(["length mismatch", len(merges), len(data['merges'])])

## Q1.2 Problem (train_bpe_tinystories): BPE Training on TinyStories 

In [None]:
# Load the memory_profiler IPython extension, which provides the %memit magic
# used by the timing cell below.
%load_ext memory_profiler

In [14]:
import time

In [21]:
# Inputs for the timed training run below.
# NOTE(review): this is the `-valid` file, the same one profiled in Q1.1 —
# confirm whether a different split was intended for this question.
input_path = r"./data/TinyStoriesV2-GPT4-valid.txt"
special_tokens = ['<|endoftext|>']
# Target vocabulary size: 256 byte tokens + special tokens + learned merges.
vocab_size = 32000

In [22]:
start_time = time.time()
%memit vocab, merges = train_bpe(input_path, vocab_size, special_tokens)
end_time = time.time()
print(f"⏱️ 运行时间: {end_time - start_time:.2f} 秒")

peak memory: 356.20 MiB, increment: 23.55 MiB
⏱️ 运行时间: 85.28 秒
