In [1]:
import torch
import regex as re
from einops import rearrange, reduce, repeat
import os
import multiprocessing
from collections import defaultdict, Counter
from pretokenization_example import find_chunk_boundaries
import math
from tqdm import tqdm



In [None]:
def run_train_bpe(
    input_path: str,
    vocab_size: int,
    special_tokens: list[str],
    max_loops=None,
) -> tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:
    """Train a byte-level BPE tokenizer on the corpus at `input_path` and
    return its vocabulary and merges.

    Args:
        input_path (str | os.PathLike): Path to BPE tokenizer training data (UTF-8 text).
        vocab_size (int): Total number of items in the tokenizer's vocabulary (including
            the 256 single-byte tokens and the special tokens).
        special_tokens (list[str]): A list of string special tokens to be added to the
            tokenizer vocabulary. Each one receives its own vocabulary entry right after
            the 256 byte tokens. Occurrences of a special token in the corpus act as hard
            split points: they are removed from the training text, so they contribute to
            no pair count and no merge spans across them.
        max_loops (int | None): Optional cap on the number of merges to perform. With
            `None`, training runs until `vocab_size` is reached or no pair is left.

    Returns:
        tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:
            vocab:
                The trained tokenizer vocabulary, a mapping from int (token ID in the vocabulary)
                to bytes (token bytes)
            merges:
                BPE merges. Each list item is a tuple of bytes (<token1>, <token2>),
                representing that <token1> was merged with <token2>.
                Merges are ordered by order of creation.

    Raises:
        OSError: If `input_path` cannot be opened or read.
        UnicodeDecodeError: If the corpus is not valid UTF-8.
    """
    # TODO: use multi-processing and chunked/streamed reading for large corpora.
    special_tokens = list(special_tokens or [])

    # IDs 0-255 are the raw bytes; special tokens take the following IDs in order.
    vocab = {i: bytes([i]) for i in range(256)}
    curr_vocab_size = 256
    for st in special_tokens:
        vocab[curr_vocab_size] = st.encode("utf-8")
        curr_vocab_size += 1

    merges = []

    # The Regex Pattern for Pre-tokenization
    PAT = r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

    # A buffered file object reads to EOF; a single os.read() call may legally
    # return fewer bytes than requested and would silently truncate the corpus.
    with open(input_path, "rb") as corpus:
        text = corpus.read().decode("utf-8")

    pretoken_counts = _count_pretokens(text, special_tokens, PAT)
    del text  # the raw corpus is no longer needed once it has been counted

    # Each distinct pre-token is stored once as a tuple of single-byte tokens,
    # with its corpus frequency kept at the same index in `freqs`.
    words = [tuple(bytes([b]) for b in tok.encode("utf-8")) for tok in pretoken_counts]
    freqs = list(pretoken_counts.values())

    budget = vocab_size - curr_vocab_size
    if max_loops is not None:
        budget = min(budget, max_loops)
    budget = max(budget, 0)  # vocab_size may already be exhausted by bytes + specials

    pbar = tqdm(total=budget, desc="BPE")
    try:
        for _ in range(budget):
            pair_counts = defaultdict(int)
            for word, freq in zip(words, freqs):
                for pair in zip(word, word[1:]):
                    pair_counts[pair] += freq

            if not pair_counts:
                # Every pre-token has collapsed to a single token.
                break

            # Most frequent pair wins; ties go to the lexicographically greater pair.
            best_pair = max(pair_counts.items(), key=lambda kv: (kv[1], kv[0]))[0]
            merged = best_pair[0] + best_pair[1]
            merges.append(best_pair)
            vocab[curr_vocab_size] = merged
            curr_vocab_size += 1

            words = [_merge_pair(word, best_pair, merged) for word in words]
            pbar.update(1)
    finally:
        pbar.close()

    return (vocab, merges)


def _count_pretokens(text, special_tokens, pat):
    """Split `text` on the special tokens and count each pre-token matched by `pat`."""
    # Longest first so a token that is a prefix of another cannot shadow it in
    # the alternation; empty tokens are dropped because they would match everywhere.
    delimiters = sorted((tok for tok in special_tokens if tok), key=len, reverse=True)
    if delimiters:
        splitter = "|".join(re.escape(tok) for tok in delimiters)
        chunks = [chunk for chunk in re.split(splitter, text) if chunk]
    else:
        chunks = [text]

    counts = Counter()
    for chunk in chunks:
        for match in re.finditer(pat, chunk):
            counts[match.group(0)] += 1
    return counts


def _merge_pair(word, pair, merged):
    """Return `word` with each left-to-right, non-overlapping occurrence of `pair` replaced by `merged`."""
    left, right = pair
    out = []
    i = 0
    n = len(word)
    while i < n:
        if i + 1 < n and word[i] == left and word[i + 1] == right:
            out.append(merged)
            i += 2
        else:
            out.append(word[i])
            i += 1
    return tuple(out)


In [None]:
# Train on the TinyStories validation split; max_loops limits this run to at most 1000 merges.
run_train_bpe(
    input_path="/Users/virajchhajed/Desktop/everything/fun/cs336/hw-1/data/TinyStoriesV2-GPT4-valid.txt",
    vocab_size=100_000,
    special_tokens=["<|endoftext|>"],
    max_loops=1000,
)

In [8]:
# Scratch check: pre-tokenize two sample strings (one containing emoji) and
# break every pre-token into its individual UTF-8 bytes.
chunks = ["abc", "def🟥😂"]

PAT = r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

parts = []

for c in chunks:
    for tok in re.findall(PAT, c):
        raw = tok.encode("utf8")
        print(tok)
        print(raw)
        # Slicing (not indexing) keeps each element a length-1 bytes object.
        parts.append(tuple(raw[k:k + 1] for k in range(len(raw))))

abc
b'abc'
def
b'def'
🟥😂
b'\xf0\x9f\x9f\xa5\xf0\x9f\x98\x82'


In [9]:
# Display the tuples of single-byte tokens accumulated in `parts`.
parts

[(b'a', b'b', b'c'),
 (b'd', b'e', b'f'),
 (b'\xf0', b'\x9f', b'\x9f', b'\xa5', b'\xf0', b'\x9f', b'\x98', b'\x82')]

In [18]:
# Iterating over a bytes object yields ints, so each one is converted back to
# a length-1 bytes value to print the emoji's individual UTF-8 bytes.
for i in "😂".encode("utf8"):
    print(i.to_bytes(1, "big"))

b'\xf0'
b'\x9f'
b'\x98'
b'\x82'
