In [None]:
import torch
import regex as re
from einops import rearrange, reduce, repeat
import os
from multiprocessing import Process, Lock, Pool
from collections import defaultdict, Counter
from pretokenization_example import find_chunk_boundaries
from itertools import chain
from tqdm import tqdm



In [None]:
def worker(args):
    """Pre-tokenize one byte range of the training corpus.

    Designed for ``Pool.map``, so all inputs arrive packed in a single tuple.

    Args:
        args: ``(input_path, start, end, special_tokens)``. ``start``/``end``
            are byte offsets into ``input_path``; ``special_tokens`` is a list
            of strings, or empty/None when there are no special tokens.

    Returns:
        list[tuple[bytes, ...]]: one entry per pre-token occurrence, in file
        order. Each entry is the pre-token's UTF-8 encoding split into
        single-byte ``bytes`` objects. Text matching a special token is
        dropped: the chunk is split on special tokens first, so no pre-token
        ever spans one.
    """
    input_path, start, end, special_tokens = args
    with open(input_path, "rb") as f:
        f.seek(start)
        # Undecodable bytes are silently dropped rather than raising.
        text = f.read(end - start).decode("utf-8", errors="ignore")

    if special_tokens:
        # Longest first: regex alternation takes the first branch that matches,
        # so if one special token is a prefix of another, the shorter one would
        # otherwise win and leave the tail of the longer token in the text.
        ordered = sorted(special_tokens, key=len, reverse=True)
        pattern = "|".join(re.escape(tok) for tok in ordered)
        segments = [seg for seg in re.split(pattern, text) if seg]
    else:
        segments = [text]

    # GPT-2 style pre-tokenization pattern (needs the third-party `regex`
    # module for the \p{...} Unicode property classes).
    pretoken_pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
    parts = []
    for segment in segments:
        for match in pretoken_pat.finditer(segment):
            parts.append(tuple(bytes([b]) for b in match.group().encode("utf-8")))
    return parts

In [None]:
def _merge_pair(
    word: tuple[bytes, ...], pair: tuple[bytes, bytes], merged: bytes
) -> tuple[bytes, ...]:
    """Return `word` with each occurrence of `pair`, scanned left to right without overlap, replaced by `merged`."""
    out = []
    i = 0
    n = len(word)
    while i < n:
        # Compare the token pair itself, not the concatenated bytes: a different
        # split such as (b"ab", b"c") vs (b"a", b"bc") concatenates to the same
        # bytes but is not the pair that was selected for this merge.
        if i + 1 < n and word[i] == pair[0] and word[i + 1] == pair[1]:
            out.append(merged)
            i += 2
        else:
            out.append(word[i])
            i += 1
    return tuple(out)


def run_train_bpe(
    input_path: str,
    vocab_size: int,
    special_tokens: list[str],
    max_loops=None,
) -> tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:
    """Train a byte-level BPE tokenizer on a text file.

    Args:
        input_path (str | os.PathLike): Path to the UTF-8 training corpus.
        vocab_size (int): Target vocabulary size, including the 256 byte tokens
            and the special tokens. Merging stops once this size is reached, or
            earlier if no adjacent pair is left to merge.
        special_tokens (list[str]): Strings added to the vocabulary right after
            the 256 byte tokens. They are also removed from the training text:
            `worker` splits the corpus on them before pre-tokenization, so they
            never contribute to pair counts.
        max_loops (int | None): Optional cap on the number of merges (handy for
            quick experiments). ``None`` means no cap.

    Returns:
        tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:
            vocab:
                Mapping from token ID to token bytes. IDs 0-255 are single
                bytes, then the special tokens, then one entry per merge.
            merges:
                BPE merges in order of creation; each item ``(a, b)`` means
                token ``a`` was merged with token ``b`` into ``a + b``.
    """
    # Base vocabulary: the 256 single bytes, then the special tokens.
    vocab = {i: bytes([i]) for i in range(256)}
    curr_vocab_size = 256
    for st in special_tokens:
        vocab[curr_vocab_size] = st.encode("utf-8")
        curr_vocab_size += 1

    merges = []

    # Pre-tokenize in parallel over byte ranges aligned on b"<|endoftext|>".
    with open(input_path, "rb") as f:
        num_processes = os.cpu_count() or 1
        boundaries = find_chunk_boundaries(f, num_processes, b"<|endoftext|>")
    args_list = [
        (input_path, start, end, special_tokens)
        for start, end in zip(boundaries[:-1], boundaries[1:])
    ]
    num_processes = max(1, min(num_processes, len(args_list)))
    if num_processes == 1:
        # Nothing to parallelize; avoid the process-pool startup cost.
        results = [worker(args) for args in args_list]
    else:
        # NOTE(review): under the "spawn" start method (default on macOS and
        # Windows) child processes must be able to import `worker`; a function
        # defined only in a notebook's __main__ may fail to unpickle there.
        # Consider moving `worker` into a .py module if this errors.
        with Pool(processes=num_processes) as pool:
            results = pool.map(worker, args_list)

    # Identical pre-tokens evolve identically under every merge, so train on
    # unique pre-tokens weighted by their frequency instead of every occurrence.
    word_counts = Counter(chain.from_iterable(results))

    num_merges = max(0, vocab_size - curr_vocab_size)
    if max_loops is not None:
        num_merges = max(0, min(num_merges, max_loops))

    with tqdm(total=num_merges, desc="BPE") as pbar:
        for _ in range(num_merges):
            pair_counts = defaultdict(int)
            for word, freq in word_counts.items():
                for pair in zip(word, word[1:]):
                    pair_counts[pair] += freq
            if not pair_counts:
                break  # every pre-token is already a single token

            # Most frequent pair; ties go to the lexicographically greatest pair.
            best_pair = max(pair_counts.items(), key=lambda kv: (kv[1], kv[0]))[0]
            merged = best_pair[0] + best_pair[1]
            merges.append(best_pair)
            vocab[curr_vocab_size] = merged
            curr_vocab_size += 1

            new_counts = Counter()
            for word, freq in word_counts.items():
                new_counts[_merge_pair(word, best_pair, merged)] += freq
            word_counts = new_counts

            pbar.update(1)

    return vocab, merges


In [None]:
if __name__ == "__main__":
    # Machine-specific default corpus; set BPE_INPUT_PATH to train on another
    # file without editing the notebook.
    input_path = os.environ.get(
        "BPE_INPUT_PATH",
        "/Users/virajchhajed/Desktop/everything/fun/cs336/hw-1/data/TinyStoriesV2-GPT4-valid.txt",
    )
    vocab, merges = run_train_bpe(
        input_path,
        vocab_size=1000,
        special_tokens=["<|endoftext|>"],
        max_loops=1000,
    )
    # Keep the result and report it instead of discarding the trained tokenizer.
    print(f"vocab size: {len(vocab)}, merges learned: {len(merges)}")
    print("first merges:", merges[:10])