In [7]:
import torch
import regex as re
from einops import rearrange, reduce, repeat
import os
import multiprocessing
from collections import defaultdict, Counter
from pretokenization_example import find_chunk_boundaries
import math
from tqdm import tqdm



In [8]:
test_corpus = '''low low low low low
lower lower widest widest widest
newest newest newest newest newest newest'''

In [None]:
def run_train_bpe(
    input_path: str,
    vocab_size: int,
    special_tokens: list[str],
    max_loops=None,
) -> tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:
    """Given the path to an input corpus, run train a BPE tokenizer and
    output its vocabulary and merges.

    Args:
        input_path (str | os.PathLike): Path to BPE tokenizer training data.
        vocab_size (int): Total number of items in the tokenizer's vocabulary (including special tokens).
        special_tokens (list[str]): A list of string special tokens to be added to the tokenizer vocabulary.
            These strings will never be split into multiple tokens, and will always be
            kept as a single token. If these special tokens occur in the `input_path`,
            they are treated as any other string.

    Returns:
        tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:
            vocab:
                The trained tokenizer vocabulary, a mapping from int (token ID in the vocabulary)
                to bytes (token bytes)
            merges:
                BPE merges. Each list item is a tuple of bytes (<token1>, <token2>),
                representing that <token1> was merged with <token2>.
                Merges are ordered by order of creation.
    """
    # use multi-processing and chunks [LATER]


    vocab = {i: bytes([i]) for i in range(256)}
    curr_vocab_size = 256
    for st in special_tokens:
        vocab[curr_vocab_size] = st.encode("utf-8")
        curr_vocab_size += 1
    
    merges = []

     # The Regex Pattern for Pre-tokenization
    PAT = r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

    # Think about whether you'd like to replace with a streamed/buffed implementation
    fd = os.open(input_path, os.O_RDONLY)
    data = os.read(fd, os.path.getsize(input_path))
    os.close(fd)
    text = data.decode("utf-8")

    #gotta strip the text of the end of text symbol here [LATER]

    #replace this with finditer later [LATER]
    # parts = re.findall(PAT, text)
    parts = re.findall(r"\S+", test_corpus)
    parts = [tuple(i) for i in parts]

    loop_counter = 0
    pbar = tqdm(desc="BPE")
    while (max_loops is None or loop_counter < max_loops ) and (curr_vocab_size < vocab_size):
        merge_dict = defaultdict(int)
        for part in parts:
            total_letters = len(part)
            k = 0
            while (k+1) < total_letters:
                merge_dict[(part[k], part[k+1])] += 1
                k +=1

        #lexiographically higherst occuring one:
        if not merge_dict:
            break
        new_token = max(merge_dict.items(), key=lambda x: (x[1], x[0]))[0] #max is EMPTY [TODO]
        merges.append(new_token)
        new_token = ''.join(new_token)
        vocab[curr_vocab_size] = new_token.encode("utf8")
        curr_vocab_size += 1

        # continue from here
        to_be_removed = []
        to_be_added = []
        for i in parts:
            lenght = len(i)
            j = 0
            while (j+1) < lenght:
                if (i[j] + i[j+1]) == new_token:
                    to_be_removed.append(i)
                    new_i = i[:j] + (new_token,) + i[j+2:]
                    to_be_added.append(new_i)
                j+= 1

        for i in to_be_removed:
            parts.remove(i)
        parts.extend(to_be_added)
        
        pbar.update(1)
        loop_counter+= 1
        
    pbar.close()
    return (vocab, merges)


            
    

In [22]:
run_train_bpe(
    "/Users/virajchhajed/Desktop/everything/fun/cs336/hw-1/data/TinyStoriesV2-GPT4-valid.txt",
    100_000,
    ["<|endoftext|>"],
    1000
)



BPE: 12it [00:00, 17829.13it/s]

[('l', 'o', 'w'), ('l', 'o', 'w'), ('l', 'o', 'w'), ('l', 'o', 'w'), ('l', 'o', 'w'), ('l', 'o', 'w', 'e', 'r'), ('l', 'o', 'w', 'e', 'r'), ('w', 'i', 'd', 'e', 's', 't'), ('w', 'i', 'd', 'e', 's', 't'), ('w', 'i', 'd', 'e', 's', 't'), ('n', 'e', 'w', 'e', 's', 't'), ('n', 'e', 'w', 'e', 's', 't'), ('n', 'e', 'w', 'e', 's', 't'), ('n', 'e', 'w', 'e', 's', 't'), ('n', 'e', 'w', 'e', 's', 't'), ('n', 'e', 'w', 'e', 's', 't')]
defaultdict(<class 'int'>, {('l', 'o'): 7, ('o', 'w'): 7, ('w', 'e'): 8, ('e', 'r'): 2, ('w', 'i'): 3, ('i', 'd'): 3, ('d', 'e'): 3, ('e', 's'): 9, ('s', 't'): 9, ('n', 'e'): 6, ('e', 'w'): 6})
[('l', 'o', 'w'), ('l', 'o', 'w'), ('l', 'o', 'w'), ('l', 'o', 'w'), ('l', 'o', 'w'), ('l', 'o', 'w', 'e', 'r'), ('l', 'o', 'w', 'e', 'r'), ('w', 'i', 'd', 'e', 'st'), ('w', 'i', 'd', 'e', 'st'), ('w', 'i', 'd', 'e', 'st'), ('n', 'e', 'w', 'e', 'st'), ('n', 'e', 'w', 'e', 'st'), ('n', 'e', 'w', 'e', 'st'), ('n', 'e', 'w', 'e', 'st'), ('n', 'e', 'w', 'e', 'st'), ('n', 'e', 'w'




({0: b'\x00',
  1: b'\x01',
  2: b'\x02',
  3: b'\x03',
  4: b'\x04',
  5: b'\x05',
  6: b'\x06',
  7: b'\x07',
  8: b'\x08',
  9: b'\t',
  10: b'\n',
  11: b'\x0b',
  12: b'\x0c',
  13: b'\r',
  14: b'\x0e',
  15: b'\x0f',
  16: b'\x10',
  17: b'\x11',
  18: b'\x12',
  19: b'\x13',
  20: b'\x14',
  21: b'\x15',
  22: b'\x16',
  23: b'\x17',
  24: b'\x18',
  25: b'\x19',
  26: b'\x1a',
  27: b'\x1b',
  28: b'\x1c',
  29: b'\x1d',
  30: b'\x1e',
  31: b'\x1f',
  32: b' ',
  33: b'!',
  34: b'"',
  35: b'#',
  36: b'$',
  37: b'%',
  38: b'&',
  39: b"'",
  40: b'(',
  41: b')',
  42: b'*',
  43: b'+',
  44: b',',
  45: b'-',
  46: b'.',
  47: b'/',
  48: b'0',
  49: b'1',
  50: b'2',
  51: b'3',
  52: b'4',
  53: b'5',
  54: b'6',
  55: b'7',
  56: b'8',
  57: b'9',
  58: b':',
  59: b';',
  60: b'<',
  61: b'=',
  62: b'>',
  63: b'?',
  64: b'@',
  65: b'A',
  66: b'B',
  67: b'C',
  68: b'D',
  69: b'E',
  70: b'F',
  71: b'G',
  72: b'H',
  73: b'I',
  74: b'J',
  75: b'K',
  76: b'

- how does multiprocessing work?
