In [57]:
import torch
import regex as re
from einops import rearrange, reduce, repeat
import os
import multiprocessing
from collections import defaultdict, Counter
from pretokenization_example import find_chunk_boundaries

In [71]:
# GPT-2 pre-tokenization pattern. Needs the third-party `regex` module (imported
# as `re` in this notebook) for the \p{L} / \p{N} Unicode property classes.
PAT = r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""


def _pretokenize(text: str, special_tokens: list[str]) -> dict[tuple[bytes, ...], int]:
    """Count the pre-tokens of `text`, each keyed as a tuple of single-byte `bytes`.

    The text is first split on the special tokens (which are dropped), so no
    pre-token -- and therefore no BPE merge -- ever spans or contains one.
    """
    if special_tokens:
        # Longest first, so a special token that is a prefix of another one
        # cannot win the alternation and leave a fragment behind.
        alternation = "|".join(
            re.escape(tok) for tok in sorted(special_tokens, key=len, reverse=True)
        )
        segments = re.split(alternation, text)
    else:
        segments = [text]

    # Count distinct strings first (cheap), then convert each once to bytes.
    str_counts: Counter = Counter()
    for segment in segments:
        str_counts.update(match.group() for match in re.finditer(PAT, segment))

    # UTF-8 is injective, so distinct strings map to distinct byte tuples.
    return {
        tuple(bytes([b]) for b in s.encode("utf-8")): count
        for s, count in str_counts.items()
    }


def _merge_word(
    word: tuple[bytes, ...], pair: tuple[bytes, bytes], merged: bytes
) -> tuple[bytes, ...]:
    """Replace each non-overlapping occurrence of `pair` in `word`, scanning left to right."""
    out = []
    i = 0
    while i < len(word):
        if i + 1 < len(word) and word[i] == pair[0] and word[i + 1] == pair[1]:
            out.append(merged)
            i += 2
        else:
            out.append(word[i])
            i += 1
    return tuple(out)


def run_train_bpe(
    input_path: str,
    vocab_size: int,
    special_tokens: list[str],
    **kwargs,
) -> tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:
    """Given the path to an input corpus, train a byte-level BPE tokenizer and
    output its vocabulary and merges.

    Args:
        input_path (str | os.PathLike): Path to BPE tokenizer training data (UTF-8 text).
        vocab_size (int): Total number of items in the tokenizer's vocabulary (including special tokens).
            Merging stops once this size is reached or no adjacent pair is left to merge.
            If it is smaller than 256 + len(special_tokens), no merges are performed.
        special_tokens (list[str]): A list of string special tokens to be added to the tokenizer vocabulary.
            These strings will never be split into multiple tokens, and will always be
            kept as a single token. Occurrences of them in the corpus act as hard
            boundaries: the text is split on them before pre-tokenization, so no
            merge ever crosses or includes a special token.
        **kwargs: Accepted for interface compatibility; ignored.

    Returns:
        tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:
            vocab:
                The trained tokenizer vocabulary, a mapping from int (token ID in the vocabulary)
                to bytes (token bytes). IDs 0-255 are the single bytes, followed by the
                special tokens in the given order, followed by one ID per merge.
            merges:
                BPE merges. Each list item is a tuple of bytes (<token1>, <token2>),
                representing that <token1> was merged with <token2>.
                Merges are ordered by order of creation.
    """
    # TODO: parallelize pre-tokenization over chunks (find_chunk_boundaries + multiprocessing).

    vocab: dict[int, bytes] = {i: bytes([i]) for i in range(256)}
    for tok in special_tokens:
        vocab[len(vocab)] = tok.encode("utf-8")
    merges: list[tuple[bytes, bytes]] = []

    # Read the corpus once, as raw bytes, so line endings are preserved exactly.
    with open(input_path, "rb") as f:
        text = f.read().decode("utf-8")

    word_counts = _pretokenize(text, special_tokens)
    words = list(word_counts)
    freqs = list(word_counts.values())

    # pair -> occurrences weighted by pre-token frequency;
    # pair -> indices of the words that currently contain it.
    pair_counts: defaultdict = defaultdict(int)
    pair_to_words: defaultdict = defaultdict(set)
    for idx, word in enumerate(words):
        for pair in zip(word, word[1:]):
            pair_counts[pair] += freqs[idx]
            pair_to_words[pair].add(idx)

    while len(vocab) < vocab_size and pair_counts:
        # Most frequent pair; ties are broken by the lexicographically greatest pair.
        best = max(pair_counts.items(), key=lambda kv: (kv[1], kv[0]))[0]
        merged = best[0] + best[1]
        merges.append(best)
        vocab[len(vocab)] = merged

        # Update only the words that contain `best` instead of recounting everything.
        touched = set()
        for idx in pair_to_words.pop(best, ()):
            old = words[idx]
            new = _merge_word(old, best, merged)
            freq = freqs[idx]
            old_pairs = list(zip(old, old[1:]))
            new_pairs = list(zip(new, new[1:]))
            for pair in old_pairs:
                pair_counts[pair] -= freq
            for pair in new_pairs:
                pair_counts[pair] += freq
            for pair in set(old_pairs).difference(new_pairs):
                holders = pair_to_words.get(pair)
                if holders is not None:
                    holders.discard(idx)
            for pair in new_pairs:
                pair_to_words[pair].add(idx)
            words[idx] = new
            touched.update(old_pairs)

        # Drop pairs that no longer occur anywhere so max() only sees live candidates.
        for pair in touched:
            if pair_counts.get(pair, 0) <= 0:
                pair_counts.pop(pair, None)
                pair_to_words.pop(pair, None)

    return vocab, merges


            
    

In [None]:
# Corpus location: override with the CS336_DATA_DIR environment variable rather
# than hard-coding a machine-specific absolute path.
DATA_DIR = os.environ.get("CS336_DATA_DIR", "data")
TINYSTORIES_VALID = os.path.join(DATA_DIR, "TinyStoriesV2-GPT4-valid.txt")
VOCAB_SIZE = 100_000
SPECIAL_TOKENS = ["<|endoftext|>"]

vocab, merges = run_train_bpe(TINYSTORIES_VALID, VOCAB_SIZE, SPECIAL_TOKENS)

# Show a summary instead of the full vocabulary, which would flood the cell output.
len(vocab), merges[:10]

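- **TODO:** How does `multiprocessing` work? It is needed to pre-tokenize the file chunks returned by `find_chunk_boundaries` in parallel.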
