In [1]:
import pandas as pd
from collections import defaultdict
import regex as re
from multiprocessing import Pool
from support.find_chunk_boundaries import find_chunk_boundaries
from memory_profiler import profile
import time, tracemalloc
from dataclasses import dataclass

## Q1.1 Problem (train_bpe): BPE Tokenizer Training 


### Version 1.0

In [2]:
from collections import defaultdict
import regex as re


def load_txt_as_str(input_path: str) -> str:
    """Return the full contents of the UTF-8 text file at ``input_path``.

    Args:
        input_path: Path of the file to read.

    Returns:
        The decoded file contents as a single string.
    """
    with open(input_path, "r", encoding="utf-8") as handle:
        return handle.read()

def split_string(string: str, special_tokens: list[str]) -> list[str]:
    """Split ``string`` at every occurrence of any special token.

    The special tokens themselves are dropped from the result, so merges can
    never span a special-token boundary.

    Args:
        string: Text to split.
        special_tokens: Literal token strings to split on (regex-escaped here).

    Returns:
        The pieces of text between special tokens. When there is nothing to
        split on, the whole text is returned as a single piece.
    """
    # An empty pattern (no tokens, or only empty strings) would match at every
    # position instead of leaving the text intact.
    usable_tokens = [tok for tok in special_tokens if tok]
    if not usable_tokens:
        return [string]
    # Longest first: alternation picks the first alternative that matches, so a
    # token that is a prefix of a longer one must come after it.
    usable_tokens.sort(key=len, reverse=True)
    pattern = "|".join(re.escape(tok) for tok in usable_tokens)
    return re.split(pattern, string)

# GPT-2 style pre-tokenization pattern. Alternatives, in order: English
# contraction suffixes ('s, 'd, 'm, 't, 'll, 've, 're); letters with an optional
# leading space; digits with an optional leading space; other non-space symbols
# with an optional leading space; whitespace not followed by a non-space
# character; any remaining whitespace.
# The \p{L} / \p{N} classes need the third-party `regex` module (imported as `re`).
PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
def get_tok_counts(string_list: list[str]) -> dict[str, int]:
    """Count pre-tokens across all text segments.

    Args:
        string_list: Text segments, each scanned independently with ``PAT``.

    Returns:
        A ``defaultdict(int)`` mapping each matched pre-token to its number of
        occurrences over all segments.
    """
    counts = defaultdict(int)
    for segment in string_list:
        for match in re.finditer(PAT, segment):
            counts[match.group(0)] += 1
    return counts

def get_byte_counts(counts: dict[str, int])-> dict[str, int]:
    """Re-key pre-token counts by their UTF-8 byte values.

    Args:
        counts: Mapping from pre-token string to occurrence count.

    Returns:
        A ``defaultdict(int)`` whose keys are tuples of byte values (ints) and
        whose values are the summed counts of the tokens encoding to that tuple.
    """
    byte_counts = defaultdict(int)
    for word, freq in counts.items():
        byte_counts[tuple(word.encode("utf-8"))] += freq
    return byte_counts

def get_pair_counts(element_counts: dict[str, int]) -> dict[tuple[int,int], int]:
    """Count adjacent element pairs, weighted by token frequency.

    Args:
        element_counts: Mapping from a token (sequence of ids) to its count.

    Returns:
        A ``defaultdict(int)`` mapping each adjacent pair ``(left, right)`` to
        the total count of the tokens it occurs in (once per occurrence).
    """
    pair_counts = defaultdict(int)
    for elements, count in element_counts.items():
        for left, right in zip(elements, elements[1:]):
            pair_counts[(left, right)] += count
    return pair_counts


def update_element_counts(byte_level_counts: dict[str, int], pair: tuple[int, int], new_index: int) -> dict[str, int]:
    """Apply one merge to every token and return the re-keyed counts.

    Each token is scanned left to right; every non-overlapping occurrence of
    ``pair`` is replaced by ``new_index``.

    Args:
        byte_level_counts: Mapping from token (tuple of ids) to its count.
        pair: The two adjacent ids to merge.
        new_index: Id that replaces each occurrence of ``pair``.

    Returns:
        A new plain dict keyed by the rewritten tokens; the input is not modified.
    """
    first, second = pair[0], pair[1]
    merged_counts = {}
    for elements, counts in byte_level_counts.items():
        last = len(elements) - 1
        rebuilt = []
        skip_next = False
        for pos, current in enumerate(elements):
            if skip_next:
                # This element was consumed as the right half of a merge.
                skip_next = False
                continue
            if pos < last and current == first and elements[pos + 1] == second:
                rebuilt.append(new_index)
                skip_next = True
            else:
                rebuilt.append(current)
        merged_counts[tuple(rebuilt)] = counts
    return merged_counts

def initiate_vocab(special_tokens: list[str]) ->  dict[int, bytes]:
    """Build the starting vocabulary: all single bytes, then the special tokens.

    Args:
        special_tokens: Token strings, assigned ids 256, 257, ... in list order.

    Returns:
        Mapping from token id to its byte string.
    """
    vocab = {byte: bytes([byte]) for byte in range(256)}
    vocab.update(
        (256 + offset, tok.encode("utf-8"))
        for offset, tok in enumerate(special_tokens)
    )
    return vocab

def find_max_pair(pair_counts: dict[tuple[int,int], int], vocab:dict[int, bytes]) -> tuple[int, int]:
    """Pick the pair to merge next.

    The most frequent pair wins; ties are broken by taking the pair whose
    ``(vocab[left], vocab[right])`` byte strings compare greatest.

    Args:
        pair_counts: Mapping from id pair to frequency. Must be non-empty
            (``max`` raises ``ValueError`` otherwise).
        vocab: Mapping from id to bytes, used only for tie-breaking.

    Returns:
        The selected ``(left, right)`` id pair.
    """
    top_count = max(pair_counts.values())
    tied = [candidate for candidate, freq in pair_counts.items() if freq == top_count]
    return max(tied, key=lambda candidate: (vocab[candidate[0]], vocab[candidate[1]]))


def pre_tokenize(string: str,vocab_size: int,special_tokens: list[str]) -> tuple[dict[int, bytes],list[tuple[bytes, bytes]]]:
    """Run the full Version 1.0 BPE training loop on an in-memory text.

    Despite its name this does more than pre-tokenize: it splits on special
    tokens, counts pre-tokens, then merges the best pair repeatedly, recounting
    every pair from scratch after each merge.

    Args:
        string: Training text.
        vocab_size: Target vocabulary size, counting bytes and special tokens.
        special_tokens: Tokens that delimit documents and get reserved ids.

    Returns:
        ``(vocab, merges)``: the id-to-bytes mapping and the merges in the
        order they were learned. Training stops early when no pairs are left.
    """
    segments = split_string(string, special_tokens)
    byte_level_counts = get_byte_counts(get_tok_counts(segments))
    vocab = initiate_vocab(special_tokens)
    merges = []

    next_index = len(vocab)
    while next_index < vocab_size:
        pair_counts = get_pair_counts(byte_level_counts)
        if not pair_counts:
            # Every token is a single symbol: nothing left to merge.
            break
        left, right = find_max_pair(pair_counts, vocab)
        left_bytes, right_bytes = vocab[int(left)], vocab[int(right)]
        byte_level_counts = update_element_counts(byte_level_counts, (left, right), next_index)
        merges.append((left_bytes, right_bytes))
        vocab[next_index] = left_bytes + right_bytes
        next_index += 1
    return vocab, merges

def train_bpe(input_path: str, vocab_size: int, special_tokens: list[str]) -> tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:
    """Train a BPE tokenizer on a text file (Version 1.0 entry point).

    Args:
        input_path: Path of the UTF-8 training text.
        vocab_size: Target vocabulary size.
        special_tokens: Tokens that delimit documents and get reserved ids.

    Returns:
        ``(vocab, merges)`` as produced by ``pre_tokenize``.
    """
    corpus = load_txt_as_str(input_path)
    return pre_tokenize(corpus, vocab_size, special_tokens)

### Version 2.0

In [None]:
# Type aliases used by the Version 2.0 functions below.
# Pair: two adjacent token ids that are candidates for a merge.
Pair = tuple[int,int]
# Encoded_Token: one pre-token expressed as a variable-length tuple of token ids.
Encoded_Token = tuple[int, ...]
# GPT-2 style pre-tokenization pattern; \p{L} / \p{N} need the `regex` module.
PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

In [None]:

@dataclass(frozen=True)
class TokenMergePlan:
    """Immutable record of how one encoded token changes under a merge."""

    # Token as it is keyed before the merge.
    old_token: Encoded_Token
    # Token after the merged pair has been replaced by the new id.
    new_token: Encoded_Token
    # Frequency of the token in the corpus.
    count: int
    # Start indices in ``old_token`` at which the merged pair was found.
    pair_positions: list[int]

class PairFreqsDelta():
    """Pending changes to pair frequencies caused by one merge.

    Attributes:
        inc: Amount to add to each pair's frequency.
        dec: Amount to subtract from each pair's frequency.
    """
    inc: defaultdict[Pair, int]
    # Previously `inc` was annotated twice and `dec` was never declared.
    dec: defaultdict[Pair, int]
    def __init__(self):
        self.inc = defaultdict(int)
        self.dec = defaultdict(int)

class PairInhereitDelta():
    """Pending changes to the pair -> tokens index caused by one merge.

    Attributes:
        add: For each pair, the tokens that now contain it.
        remove: For each pair, the tokens that no longer exist under that key.
    """
    add: defaultdict[Pair,set[Encoded_Token]]
    remove: defaultdict[Pair,set[Encoded_Token]]
    def __init__(self):
        # Use the plain `set` constructor as the factory; the parameterized
        # alias `set[Encoded_Token]` belongs in annotations only.
        self.add = defaultdict(set)
        self.remove = defaultdict(set)

In [None]:
def read_text_file(input_path: str) -> str:
    """Return the full contents of the UTF-8 text file at ``input_path``.

    Args:
        input_path: Path of the file to read.

    Returns:
        The decoded file contents as a single string.
    """
    with open(input_path, "r", encoding="utf-8") as source:
        return source.read()

def split_by_special_tokens(string: str, special_tokens: list[str]) -> list[str]:
    """Split ``string`` at every occurrence of any special token.

    The special tokens themselves are dropped from the result, so merges can
    never span a special-token boundary.

    Args:
        string: Text to split.
        special_tokens: Literal token strings to split on (regex-escaped here).

    Returns:
        The pieces of text between special tokens. When there is nothing to
        split on, the whole text is returned as a single piece.
    """
    # An empty pattern (no tokens, or only empty strings) would match at every
    # position instead of leaving the text intact.
    usable_tokens = [tok for tok in special_tokens if tok]
    if not usable_tokens:
        return [string]
    # Longest first: alternation picks the first alternative that matches, so a
    # token that is a prefix of a longer one must come after it.
    usable_tokens.sort(key=len, reverse=True)
    pattern = "|".join(re.escape(tok) for tok in usable_tokens)
    return re.split(pattern, string)

def count_tokens(string_list: list[str]) -> dict[str, int]:
    """Count pre-tokens across all text segments.

    Args:
        string_list: Text segments, each scanned independently with ``PAT``.

    Returns:
        A ``defaultdict(int)`` mapping each matched pre-token to its number of
        occurrences over all segments.
    """
    counts = defaultdict(int)
    for segment in string_list:
        for match in re.finditer(PAT, segment):
            counts[match.group(0)] += 1
    return counts

def encode_and_count_tokens(counts: dict[str, int])-> dict[Encoded_Token, int]:
    """Re-key pre-token counts by their UTF-8 byte values.

    Args:
        counts: Mapping from pre-token string to occurrence count.

    Returns:
        A ``defaultdict(int)`` keyed by tuples of byte values, holding the
        summed counts of the tokens that encode to each tuple.
    """
    encoded_token_freqs = defaultdict(int)
    for text, freq in counts.items():
        encoded_token_freqs[tuple(text.encode("utf-8"))] += freq
    return encoded_token_freqs

def build_initial_vocab(special_tokens: list[str]) ->  dict[int, bytes]:
    """Build the starting vocabulary: all single bytes, then the special tokens.

    Args:
        special_tokens: Token strings, assigned ids 256, 257, ... in list order.

    Returns:
        Mapping from token id to its byte string.
    """
    vocab = {byte: bytes([byte]) for byte in range(256)}
    vocab.update(
        (256 + offset, tok.encode("utf-8"))
        for offset, tok in enumerate(special_tokens)
    )
    return vocab

def get_byte_pairs(encoded_token_freqs: dict[Encoded_Token, int]) -> dict[Pair, int]:
    """Count adjacent id pairs, weighted by token frequency.

    Args:
        encoded_token_freqs: Mapping from encoded token to its count.

    Returns:
        A ``defaultdict(int)`` mapping each adjacent pair to the total count of
        the tokens it occurs in (once per occurrence).
    """
    pair_freqs = defaultdict(int)
    for tok, count in encoded_token_freqs.items():
        for left, right in zip(tok, tok[1:]):
            pair_freqs[(left, right)] += count
    return pair_freqs

def get_byte_pairs_inhereit(encoded_token_freqs: dict[Encoded_Token, int]) -> dict[Pair, set[Encoded_Token]]:
    """Build the reverse index from each adjacent pair to the tokens containing it.

    Args:
        encoded_token_freqs: Mapping from encoded token to its count; only the
            keys are used.

    Returns:
        A ``defaultdict(set)`` mapping each pair to the set of tokens in which
        it appears at least once.
    """
    pair_inhereit = defaultdict(set)
    for tok in encoded_token_freqs:
        for adjacent in zip(tok, tok[1:]):
            pair_inhereit[adjacent].add(tok)
    return pair_inhereit

def select_merge_pair(pair_freqs: dict[Pair, int], vocab:dict[int, bytes]) -> Pair:
    """Pick the pair to merge next.

    The most frequent pair wins; ties are broken by taking the pair whose
    ``(vocab[left], vocab[right])`` byte strings compare greatest.

    Args:
        pair_freqs: Mapping from id pair to frequency. Must be non-empty
            (``max`` raises ``ValueError`` otherwise).
        vocab: Mapping from id to bytes, used only for tie-breaking.

    Returns:
        The selected ``(left, right)`` id pair.
    """
    top_count = max(pair_freqs.values())
    tied = [candidate for candidate, freq in pair_freqs.items() if freq == top_count]
    return max(tied, key=lambda candidate: (vocab[candidate[0]], vocab[candidate[1]]))

def update_encoded_token(encoded_token: Encoded_Token, pair: Pair, new_index: int) -> Encoded_Token:
    """Replace each non-overlapping occurrence of ``pair`` with ``new_index``.

    The token is scanned left to right, so in a run such as ``(a, a, a)`` with
    pair ``(a, a)`` only the first two elements are merged.

    Args:
        encoded_token: Token ids to rewrite.
        pair: The two adjacent ids to merge.
        new_index: Id that replaces each occurrence of ``pair``.

    Returns:
        The rewritten token as a new tuple.
    """
    last = len(encoded_token) - 1
    merged = []
    skip_next = False
    for pos, current in enumerate(encoded_token):
        if skip_next:
            # This element was consumed as the right half of a merge.
            skip_next = False
            continue
        if pos < last and (current, encoded_token[pos + 1]) == pair:
            merged.append(new_index)
            skip_next = True
        else:
            merged.append(current)
    return tuple(merged)

def find_subtuple_index(sequence: tuple, subseq: tuple) -> list[int]:
    """Return every start index at which ``subseq`` occurs in ``sequence``.

    Overlapping occurrences are all reported, e.g. ``(1, 1)`` is found at
    indices 0 and 1 of ``(1, 1, 1)``.

    Args:
        sequence: Tuple to search in.
        subseq: Tuple to search for.

    Returns:
        Start indices in ascending order.
    """
    width = len(subseq)
    return [
        start
        for start in range(len(sequence) - width + 1)
        if sequence[start:start + width] == subseq
    ]

def remove_or_decrement_pair(pair_freqs: dict[Pair, int], pair: Pair, count: int) -> dict[Pair, int]:
    """Return a copy of ``pair_freqs`` with ``pair`` reduced by ``count``.

    The entry is deleted when its frequency reaches exactly zero; otherwise the
    reduced value is stored. The input mapping is left untouched.

    Args:
        pair_freqs: Current pair frequencies.
        pair: Pair whose frequency is reduced.
        count: Amount to subtract.

    Returns:
        The updated copy (same mapping type as the input).
    """
    result = pair_freqs.copy()
    remaining = result[pair] - count
    if remaining == 0:
        del result[pair]
    else:
        result[pair] = remaining
    return result

def build_merge_plan(tok_need_update: set[Encoded_Token], encoded_token_freqs: dict[Encoded_Token, int], pair: Pair, new_index: int) -> list[TokenMergePlan]:
    """Describe how each affected token changes when ``pair`` is merged.

    Args:
        tok_need_update: Tokens to rewrite.
        encoded_token_freqs: Mapping from token to its frequency.
        pair: The two adjacent ids being merged.
        new_index: Id assigned to the merged pair.

    Returns:
        One ``TokenMergePlan`` per token, holding the old token, the rewritten
        token, its frequency and the positions where ``pair`` was found.
    """
    return [
        TokenMergePlan(
            old_token,
            update_encoded_token(old_token, pair, new_index),
            encoded_token_freqs[old_token],
            find_subtuple_index(old_token, pair),
        )
        for old_token in tok_need_update
    ]

def update_encoded_token_freqs(plan: list[TokenMergePlan], encoded_token_freqs: dict[Encoded_Token, int]) ->  dict[Encoded_Token, int]:
    """Re-key token frequencies according to a merge plan.

    Args:
        plan: Merge plan entries; each old token must be present in
            ``encoded_token_freqs`` (``KeyError`` otherwise).
        encoded_token_freqs: Current token frequencies; not modified.

    Returns:
        A copy in which every old token is replaced by its rewritten form with
        the count recorded in the plan.
    """
    refreshed = encoded_token_freqs.copy()
    for entry in plan:
        refreshed.pop(entry.old_token)
        refreshed[entry.new_token] = entry.count
    return refreshed

def compute_freqs_deltas(plan: list[TokenMergePlan], new_index:int) -> PairFreqsDelta:
    """Compute how pair frequencies change when the plan's merge is applied.

    For every plan entry the adjacent pairs of the old token are compared with
    those of the rewritten token; the per-pair difference, multiplied by the
    token's frequency, is accumulated into ``inc`` (pairs gained) and ``dec``
    (pairs lost).

    Diffing the two tokens, rather than patching the neighbours of each
    recorded position, keeps the result consistent with the left-to-right,
    non-overlapping rewrite stored in ``new_token``. Position-based patching
    miscounts adjacent occurrences such as ``(a, b, a, b)`` and overlapping
    ones such as ``(a, a, a)``.

    The merged pair itself is left out of ``dec``: the training loop removes
    that entry from the frequency table separately.

    Args:
        plan: Merge plan entries (old token, new token, count, positions).
        new_index: Id of the merged pair. Kept for interface compatibility;
            the new pairs are read from each entry's ``new_token``.

    Returns:
        A ``PairFreqsDelta`` with the accumulated increments and decrements.
    """
    pair_freqs_d = PairFreqsDelta()
    for item in plan:
        old_token = item.old_token
        new_token = item.new_token

        # Recover the merged pair from the first recorded position, if any.
        merged_pair = None
        if item.pair_positions:
            first = item.pair_positions[0]
            merged_pair = (old_token[first], old_token[first + 1])

        old_pairs = defaultdict(int)
        for adjacent in zip(old_token, old_token[1:]):
            old_pairs[adjacent] += 1
        new_pairs = defaultdict(int)
        for adjacent in zip(new_token, new_token[1:]):
            new_pairs[adjacent] += 1

        for adjacent, occurrences in old_pairs.items():
            if adjacent == merged_pair:
                continue
            lost = occurrences - new_pairs.get(adjacent, 0)
            if lost > 0:
                pair_freqs_d.dec[adjacent] += lost * item.count
        for adjacent, occurrences in new_pairs.items():
            gained = occurrences - old_pairs.get(adjacent, 0)
            if gained > 0:
                pair_freqs_d.inc[adjacent] += gained * item.count
    return pair_freqs_d

def compute_inhereit_deltas(plan: list[TokenMergePlan]) -> PairInhereitDelta:
    """Compute how the pair -> tokens index changes under a merge plan.

    Args:
        plan: Merge plan entries.

    Returns:
        A ``PairInhereitDelta`` where ``remove`` maps every adjacent pair of an
        old token to that old token, and ``add`` maps every adjacent pair of a
        rewritten token to that rewritten token.
    """
    pair_inhereit_d = PairInhereitDelta()
    for entry in plan:
        before, after = entry.old_token, entry.new_token
        # zip yields nothing for tokens shorter than two elements.
        for stale_pair in zip(before, before[1:]):
            pair_inhereit_d.remove[stale_pair].add(before)
        for fresh_pair in zip(after, after[1:]):
            pair_inhereit_d.add[fresh_pair].add(after)
    return pair_inhereit_d

def exclude_pair_from_dict(d: dict, pair: Pair):
    """Return a shallow copy of ``d`` without the entry for ``pair``.

    Args:
        d: Mapping keyed by pair; not modified.
        pair: Key to drop. Must be present (``KeyError`` otherwise).

    Returns:
        The copy, of the same mapping type as ``d``.
    """
    remaining = d.copy()
    remaining.pop(pair)
    return remaining

def update_pair_freqs(pair_freqs: dict[Pair, int], pair_freqs_d: PairFreqsDelta):
    """Apply frequency decrements and increments to a copy of ``pair_freqs``.

    The table is copied once and updated in place, instead of being copied
    again for every decremented pair. Pairs whose frequency drops to zero or
    below are removed so they cannot be selected for a later merge.

    Args:
        pair_freqs: Current pair frequencies; not modified.
        pair_freqs_d: Object exposing ``dec`` and ``inc`` mappings from pair to
            the amount to subtract / add.

    Returns:
        The updated copy (same mapping type as the input).
    """
    new_pair_freqs = pair_freqs.copy()
    for key, value in pair_freqs_d.dec.items():
        remaining = new_pair_freqs.get(key, 0) - value
        if remaining > 0:
            new_pair_freqs[key] = remaining
        else:
            new_pair_freqs.pop(key, None)
    for key, value in pair_freqs_d.inc.items():
        # .get keeps this working for plain dicts as well as defaultdicts.
        new_pair_freqs[key] = new_pair_freqs.get(key, 0) + value
    return new_pair_freqs

def update_pair_inhereit(pair_inhereit: dict[Pair, set[Encoded_Token]], pair_inhereit_d: PairInhereitDelta):
    """Apply removals and additions to a copy of the pair -> tokens index.

    ``dict.copy()`` is shallow, so the sets are still shared with the input.
    New sets are therefore built here rather than updating them in place,
    which would silently change the caller's index. Pairs that end up with no
    tokens are dropped, and a pair absent from the index is not re-created
    just because a removal mentions it.

    Args:
        pair_inhereit: Current index from pair to the tokens containing it;
            not modified.
        pair_inhereit_d: Object exposing ``remove`` and ``add`` mappings from
            pair to a set of tokens.

    Returns:
        The updated copy (same mapping type as the input).
    """
    new_pair_inhereit = pair_inhereit.copy()
    for key, value in pair_inhereit_d.remove.items():
        if key not in new_pair_inhereit:
            continue
        kept = new_pair_inhereit[key] - value
        if kept:
            new_pair_inhereit[key] = kept
        else:
            del new_pair_inhereit[key]
    for key, value in pair_inhereit_d.add.items():
        new_pair_inhereit[key] = new_pair_inhereit.get(key, set()) | value
    return new_pair_inhereit

def train_bpe(input_path: str, vocab_size: int, special_tokens: list[str]) -> tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:
    """Train a BPE tokenizer on a text file (Version 2.0, incremental counts).

    Pair frequencies and the pair -> tokens index are built once and then
    patched after every merge, instead of being recounted from scratch.

    Args:
        input_path: Path of the UTF-8 training text.
        vocab_size: Target vocabulary size, counting bytes and special tokens.
        special_tokens: Tokens that delimit documents and get reserved ids.

    Returns:
        ``(vocab, merges)``: the id-to-bytes mapping and the merges in the
        order they were learned. Training stops early when no pairs are left.
    """
    vocab = build_initial_vocab(special_tokens)
    vocab_len = len(vocab)
    merges = []

    string = read_text_file(input_path)
    string_list = split_by_special_tokens(string,special_tokens)
    token_count = count_tokens(string_list)
    encoded_token_freqs = encode_and_count_tokens(token_count)

    pair_freqs = get_byte_pairs(encoded_token_freqs)
    pair_inhereit = get_byte_pairs_inhereit(encoded_token_freqs)

    while vocab_len < vocab_size:
        if not pair_freqs:
            # Nothing left to merge: stop instead of letting max() raise on an
            # empty table, as the other versions in this notebook do.
            break
        pair = select_merge_pair(pair_freqs, vocab)
        index1, index2 = pair
        new_token = vocab[int(index1)]+vocab[int(index2)]
        new_index = vocab_len
        vocab[new_index] = new_token
        merges.append((vocab[int(index1)], vocab[int(index2)]))

        # Rewrite only the tokens that actually contain the merged pair.
        tok_need_update = pair_inhereit[pair]
        plan = build_merge_plan(tok_need_update, encoded_token_freqs, pair, new_index)
        encoded_token_freqs = update_encoded_token_freqs(plan, encoded_token_freqs)

        # The merged pair disappears entirely; its neighbours are patched.
        pair_freqs_d = compute_freqs_deltas(plan, new_index)
        pair_freqs = exclude_pair_from_dict(pair_freqs, pair)
        pair_freqs = update_pair_freqs(pair_freqs, pair_freqs_d)

        pair_inhereit_d = compute_inhereit_deltas(plan)
        pair_inhereit = exclude_pair_from_dict(pair_inhereit, pair)
        pair_inhereit = update_pair_inhereit(pair_inhereit, pair_inhereit_d)
        vocab_len+=1
    return vocab, merges

## Version 王少东

In [None]:
def train_bpe(
    input_path: str | os.PathLike = "data/TinyStoriesV2-GPT4-valid.txt",
    vocab_size: int = 1000,
    special_tokens: list[str] = ["<|endoftext|>", "<|startoftext|>"],
) -> tuple[Vocab, list[BytePair]]:
    """
    Learn BPE merges from a corpus file.

    Steps:
    - Initialize the vocab with all single bytes (0..255) and append the
      special tokens.
    - Cut the file into chunks at ``<|endoftext|>`` boundaries and count
      pre-tokens per chunk, in a process pool when more than one worker is
      available.
    - Build pair counts and a pair -> words index once, then repeatedly merge
      the most frequent pair (ties go to the greatest pair of byte strings),
      updating only the words that contain it.

    Args:
        input_path: Path of the training text.
        vocab_size: Target vocabulary size, counting bytes and special tokens.
        special_tokens: Tokens appended to the vocab and used to split the
            text before pre-tokenization. The default list is shared between
            calls and is only read here.

    Returns:
        ``(vocab, merges)``: the id-to-bytes mapping and the merges, as pairs of
        byte strings, in the order they were learned. Training stops early when
        no pairs are left.
    """
    vocab = {i: bytes([i]) for i in range(256)}
    for token in special_tokens:
        vocab[len(vocab)] = token.encode("utf-8")

    # Pattern for GPT-2 style pretokenization
    PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

    # Chunk boundaries are aligned to this token regardless of `special_tokens`.
    split_token_bytes = b"<|endoftext|>"
    with open(input_path, "rb") as fbin:
        num_chunks = 100
        boundaries = find_chunk_boundaries(fbin, num_chunks, split_token_bytes)
    spans = list(zip(boundaries[:-1], boundaries[1:]))
    pretokenized_counter = Counter()
    if spans:
        # os.cpu_count() may return None when the count cannot be determined;
        # treat that as a single CPU instead of failing on `None - 2`.
        cpu = max(1, (os.cpu_count() or 1) - 2)
        workers = min(len(spans), cpu)
        if workers > 1:
            with mp.Pool(processes=workers) as pool:
                parts = pool.starmap(
                    count_chunk,
                    [(s, e, input_path, PAT, special_tokens) for s, e in spans],
                )
            for c in parts:
                pretokenized_counter.update(c)
        else:
            # Single span or single worker fallback
            for s, e in spans:
                pretokenized_counter.update(
                    count_chunk(s, e, input_path, PAT, special_tokens)
                )

    pair_counts = Counter()
    pair_index = defaultdict(set)  # pair -> words (token tuples) that contain the pair
    for word_seq, freq in pretokenized_counter.items():
        if len(word_seq) < 2:
            continue
        for pair in zip(word_seq, word_seq[1:]):
            pair_counts[pair] += freq
            pair_index[pair].add(word_seq)

    merges = []
    while len(vocab) < vocab_size:
        if not pair_counts:
            break

        # Select most frequent pair; break ties by lexicographically greatest pair
        max_count = max(pair_counts.values())
        candidates = [pair for pair, cnt in pair_counts.items() if cnt == max_count]
        best_pair = max(candidates)

        # Record merge and add merged token to vocab
        merges.append(best_pair)
        merged_token = best_pair[0] + best_pair[1]
        vocab[len(vocab)] = merged_token

        # Find all word sequences that contain the best pair
        affected_words = pair_index.pop(best_pair, set())
        if not affected_words:
            continue

        # merged word -> accumulated frequency; applied after the loop so the
        # counter is not mutated with new keys while old ones are processed.
        updates = {}
        for word_seq in affected_words:
            freq = pretokenized_counter.pop(word_seq, 0)
            if freq == 0:
                continue
            # remove old pair contribution
            for pair in zip(word_seq, word_seq[1:]):
                pair_counts[pair] -= freq
                if pair_counts[pair] <= 0:
                    pair_counts.pop(pair, None)
                pair_index[pair].discard(word_seq)
            # merge occurrences of best_pair in word_seq
            merged_seq = []
            i = 0
            while i < len(word_seq):
                if (
                    i + 1 < len(word_seq)
                    and word_seq[i] == best_pair[0]
                    and word_seq[i + 1] == best_pair[1]
                ):
                    merged_seq.append(merged_token)
                    i += 2
                else:
                    merged_seq.append(word_seq[i])
                    i += 1
            merged_seq = tuple(merged_seq)
            updates[merged_seq] = updates.get(merged_seq, 0) + freq

        # add the pair contributions of the rewritten words
        for w_new, freq in updates.items():
            prev_freq = pretokenized_counter.get(w_new, 0)
            for pair in zip(w_new, w_new[1:]):
                pair_counts[pair] += freq
                pair_index[pair].add(w_new)
            pretokenized_counter[w_new] = freq + prev_freq

    return vocab, merges

## 个人优化

## 对比最终结果

In [59]:
# Inputs for the end-to-end comparison run.
special_tokens = ['<|endoftext|>']
input_path = fr'./data/test.txt'
# 256 single-byte tokens + 1 special token leaves room for 23 merges.
vocab_size = 280

In [60]:
# Calls whichever `train_bpe` definition was executed last in the kernel
# (this notebook defines it more than once), then displays the learned merges.
vocab, merges = train_bpe(input_path, vocab_size, special_tokens)
merges

[(b'h', b'e'),
 (b' ', b'w'),
 (b' w', b'he'),
 (b' ', b't'),
 (b'r', b'e'),
 (b'i', b's'),
 (b' whe', b're'),
 (b' whe', b'n'),
 (b' t', b'he'),
 (b' ', b'y'),
 (b' ', b'is'),
 (b'o', b'u'),
 (b'n', b'n'),
 (b'nn', b'e'),
 (b'nne', b'r'),
 (b'n', b'i'),
 (b'ni', b'c'),
 (b'nic', b'e'),
 (b'm', b'e'),
 (b'me', b'e'),
 (b'mee', b't'),
 (b'l', b'i'),
 (b'i', b'nner')]

## 对比每一步

In [54]:
# Inputs for the step-by-step walkthrough of Version 1.0.
special_tokens = ['<|endoftext|>']
input_path = fr'./data/test.txt'
vocab_size = 280
# Inline sample text used by the cells below instead of reading `input_path`.
# NOTE(review): presumably mirrors ./data/test.txt - confirm against the file.
string = """hi. i'm yifan li. nice to meet you. <|endoftext|>
this the what when here where. <|endoftext|>
where is the car. <|endoftext|>
when is dinner"""

In [55]:
# Setup phase of `pre_tokenize`, unrolled so each intermediate can be inspected.
merges = []
string_list = split_string(string, special_tokens)
word_level_counts = get_tok_counts(string_list)
byte_level_counts = get_byte_counts(word_level_counts)
vocab = initiate_vocab(special_tokens)
vocab_len = len(vocab)
pair_counts = get_pair_counts(byte_level_counts)


In [58]:
# Rank pairs by frequency, then by the pair ids themselves, both descending.
# This orders ties by token id, whereas `find_max_pair` breaks ties by the
# vocab byte strings.
sorted_items = sorted(pair_counts.items(), key=lambda x: (x[1], x[0]), reverse=True)
sorted_items

[((119, 257), 4),
 ((257, 114), 3),
 ((114, 101), 3),
 ((105, 115), 3),
 ((32, 119), 3),
 ((32, 116), 3),
 ((32, 105), 3),
 ((257, 110), 2),
 ((116, 257), 2),
 ((104, 105), 2),
 ((32, 121), 2),
 ((121, 111), 1),
 ((121, 105), 1),
 ((119, 104), 1),
 ((116, 111), 1),
 ((116, 104), 1),
 ((111, 117), 1),
 ((110, 110), 1),
 ((110, 105), 1),
 ((110, 101), 1),
 ((109, 101), 1),
 ((108, 105), 1),
 ((105, 110), 1),
 ((105, 102), 1),
 ((105, 99), 1),
 ((104, 97), 1),
 ((102, 97), 1),
 ((101, 116), 1),
 ((101, 114), 1),
 ((101, 101), 1),
 ((100, 105), 1),
 ((99, 101), 1),
 ((99, 97), 1),
 ((97, 116), 1),
 ((97, 114), 1),
 ((97, 110), 1),
 ((39, 109), 1),
 ((32, 257), 1),
 ((32, 110), 1),
 ((32, 109), 1),
 ((32, 108), 1),
 ((32, 100), 1),
 ((32, 99), 1)]

In [57]:

# One iteration of the Version 1.0 merge loop. Each run of this cell performs
# one more merge on the kernel's current state, so it is not idempotent.
pair = find_max_pair(pair_counts, vocab)
index1, index2 = pair
new_token = vocab[int(index1)]+vocab[int(index2)]
new_index = vocab_len
byte_level_counts = update_element_counts(byte_level_counts, pair,new_index)
# Recount all pairs from the rewritten tokens.
pair_counts = get_pair_counts(byte_level_counts)
merges.append((vocab[int(index1)], vocab[int(index2)]))
vocab[new_index] = new_token
vocab_len+=1

In [52]:
# Display the token counts as currently held in the kernel.
byte_level_counts

{(104, 105): 1,
 (46,): 5,
 (32, 105): 1,
 (39, 109): 1,
 (32, 121, 105, 102, 97, 110): 1,
 (32, 108, 105): 1,
 (32, 110, 105, 99, 101): 1,
 (32, 116, 111): 1,
 (32, 109, 101, 101, 116): 1,
 (32, 121, 111, 117): 1,
 (32,): 3,
 (32, 116, 104, 105, 115): 1,
 (32, 116, 257): 2,
 (32, 119, 104, 97, 116): 1,
 (32, 119, 257, 110): 2,
 (32, 257, 114, 101): 1,
 (32, 119, 257, 114, 101): 2,
 (32, 105, 115): 2,
 (32, 99, 97, 114): 1,
 (32, 100, 105, 110, 110, 101, 114): 1}

## 王少东的版本

In [23]:
import os
from dataclasses import dataclass
from collections import Counter, defaultdict
import multiprocessing as mp
import regex as re
from typing import BinaryIO, Iterable, Iterator
import time
import pickle
import psutil
import numpy as np

def find_chunk_boundaries(
    file: BinaryIO,
    desired_num_chunks: int,
    split_special_token: bytes,
) -> list[int]:
    """
    Compute byte offsets that cut the file into independently countable chunks.

    Boundaries start as evenly spaced guesses; every interior guess is then
    moved forward to the next occurrence of ``split_special_token`` (or to the
    end of the file if none follows). Duplicates are collapsed, so fewer chunks
    than requested may be returned.
    """
    assert isinstance(
        split_special_token, bytes
    ), "Must represent special token as a bytestring"

    # Measure the file, then rewind.
    file.seek(0, os.SEEK_END)
    total_bytes = file.tell()
    file.seek(0)

    stride = total_bytes // desired_num_chunks

    # Uniformly spaced first guesses; the final boundary is always end-of-file.
    bounds = [k * stride for k in range(desired_num_chunks + 1)]
    bounds[-1] = total_bytes

    read_ahead = 4096  # bytes scanned per read while searching for the token

    for slot in range(1, len(bounds) - 1):
        cursor = bounds[slot]
        file.seek(cursor)
        while True:
            window = file.read(read_ahead)

            if window == b"":
                # Ran off the end without finding the token.
                bounds[slot] = total_bytes
                break

            hit = window.find(split_special_token)
            if hit >= 0:
                bounds[slot] = cursor + hit
                break
            cursor += read_ahead

    # Unique, ascending offsets.
    return sorted(set(bounds))


def count_chunk(
    start: int, end: int, path: str, pat: str, specials: list[str]
) -> Counter:
    """Count pre-tokens within the file slice [start, end).

    Minimal, picklable helper for multiprocessing: each worker opens the file
    itself and reads only its own byte range.

    Args:
        start: Byte offset where the slice begins.
        end: Byte offset where the slice ends (exclusive).
        path: Path of the file to read.
        pat: Pre-tokenization regex applied to each text segment.
        specials: Special tokens to split on before pre-tokenizing; they are
            dropped from the text. May be empty or ``None``.

    Returns:
        Counter mapping each pre-token, as a tuple of single-byte ``bytes``
        objects, to its number of occurrences in the slice.
    """
    counter = Counter()
    with open(path, "rb") as fh:
        fh.seek(start)
        raw = fh.read(end - start)
    # Undecodable bytes (e.g. a multi-byte character cut at a slice edge) are dropped.
    text = raw.decode("utf-8", errors="ignore")
    # CRLF becomes LF and any remaining lone CR is removed, so Windows line
    # endings do not introduce stray \r tokens.
    text = text.replace("\r\n", "\n").replace("\r", "")
    # Longest first: regex alternation takes the first alternative that
    # matches, so a token that is a prefix of a longer one must come later.
    split_tokens = sorted(
        (tok for tok in (specials or []) if tok), key=len, reverse=True
    )
    if split_tokens:
        # escape special tokens since some have "|" in them
        split_pat = "|".join(re.escape(tok) for tok in split_tokens)
        segments = re.split(split_pat, text)
    else:
        segments = [text]
    for segment in segments:
        if not segment:
            continue
        for match in re.finditer(pat, segment):
            token_bytes = match.group(0).encode("utf-8")
            seq = tuple(bytes([b]) for b in token_bytes)
            counter[seq] += 1
    return counter

In [49]:
# Inputs for replaying the chunked implementation step by step.
special_tokens = ['<|endoftext|>']
input_path = fr'./data/test.txt'
# 256 single-byte tokens + 1 special token leaves room for 23 merges.
vocab_size = 280

In [53]:
# Setup phase of the chunked `train_bpe`, unrolled and run serially so every
# intermediate structure can be inspected in the cells below.
vocab = {i: bytes([i]) for i in range(256)}
for token in special_tokens:
    vocab[len(vocab)] = token.encode("utf-8")
PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
# Chunk boundaries are aligned to this token.
split_token_bytes = b"<|endoftext|>"
with open(input_path, "rb") as fbin:
    num_chunks = 100 # max(1, os.cpu_count())
    boundaries = find_chunk_boundaries(fbin, num_chunks, split_token_bytes)
# Consecutive boundaries form the [start, end) byte range of each chunk.
spans = list(zip(boundaries[:-1], boundaries[1:]))
pretokenized_counter = Counter()
# Not read anywhere in this cell; the chunks below are counted one after another.
workers = 1
for s, e in spans:
    pretokenized_counter.update(
        count_chunk(s, e, input_path, PAT, special_tokens)
    )
pair_counts = Counter()
pair_index = defaultdict(set)  # pair -> words (token tuples) that contain the pair
for word_seq, freq in pretokenized_counter.items():
    # Single-symbol words contain no pairs.
    if len(word_seq) < 2:
        continue
    for pair in zip(word_seq, word_seq[1:]):
        pair_counts[pair] += freq
        pair_index[pair].add(word_seq)

In [52]:
# Display the pre-token counts (keys are tuples of single-byte `bytes`).
pretokenized_counter

Counter({(b'.',): 5,
         (b' ',): 3,
         (b'"',): 2,
         (b' ', b't', b'h', b'e'): 2,
         (b' ', b'w', b'h', b'e', b'n'): 2,
         (b' ', b'w', b'h', b'e', b'r', b'e'): 2,
         (b' ', b'i', b's'): 2,
         (b'h', b'i'): 1,
         (b' ', b'i'): 1,
         (b"'", b'm'): 1,
         (b' ', b'y', b'i', b'f', b'a', b'n'): 1,
         (b' ', b'l', b'i'): 1,
         (b' ', b'n', b'i', b'c', b'e'): 1,
         (b' ', b't', b'o'): 1,
         (b' ', b'm', b'e', b'e', b't'): 1,
         (b' ', b'y', b'o', b'u'): 1,
         (b' ', b't', b'h', b'i', b's'): 1,
         (b' ', b'w', b'h', b'a', b't'): 1,
         (b' ', b'h', b'e', b'r', b'e'): 1,
         (b' ', b'c', b'a', b'r'): 1,
         (b' ', b'd', b'i', b'n', b'n', b'e', b'r'): 1})

In [54]:
# Display the frequency of every adjacent pair.
pair_counts

Counter({(b'h', b'e'): 7,
         (b' ', b'w'): 5,
         (b'w', b'h'): 5,
         (b' ', b't'): 4,
         (b'e', b'r'): 4,
         (b' ', b'i'): 3,
         (b't', b'h'): 3,
         (b'i', b's'): 3,
         (b'r', b'e'): 3,
         (b'h', b'i'): 2,
         (b' ', b'y'): 2,
         (b'e', b'n'): 2,
         (b"'", b'm'): 1,
         (b'y', b'i'): 1,
         (b'i', b'f'): 1,
         (b'f', b'a'): 1,
         (b'a', b'n'): 1,
         (b' ', b'l'): 1,
         (b'l', b'i'): 1,
         (b' ', b'n'): 1,
         (b'n', b'i'): 1,
         (b'i', b'c'): 1,
         (b'c', b'e'): 1,
         (b't', b'o'): 1,
         (b' ', b'm'): 1,
         (b'm', b'e'): 1,
         (b'e', b'e'): 1,
         (b'e', b't'): 1,
         (b'y', b'o'): 1,
         (b'o', b'u'): 1,
         (b'h', b'a'): 1,
         (b'a', b't'): 1,
         (b' ', b'h'): 1,
         (b' ', b'c'): 1,
         (b'c', b'a'): 1,
         (b'a', b'r'): 1,
         (b' ', b'd'): 1,
         (b'd', b'i'): 1,
         (b'

In [55]:
# Display the reverse index: pair -> words that contain it.
pair_index

defaultdict(set,
            {(b'h', b'i'): {(b' ', b't', b'h', b'i', b's'), (b'h', b'i')},
             (b' ', b'i'): {(b' ', b'i'), (b' ', b'i', b's')},
             (b"'", b'm'): {(b"'", b'm')},
             (b' ', b'y'): {(b' ', b'y', b'i', b'f', b'a', b'n'),
              (b' ', b'y', b'o', b'u')},
             (b'y', b'i'): {(b' ', b'y', b'i', b'f', b'a', b'n')},
             (b'i', b'f'): {(b' ', b'y', b'i', b'f', b'a', b'n')},
             (b'f', b'a'): {(b' ', b'y', b'i', b'f', b'a', b'n')},
             (b'a', b'n'): {(b' ', b'y', b'i', b'f', b'a', b'n')},
             (b' ', b'l'): {(b' ', b'l', b'i')},
             (b'l', b'i'): {(b' ', b'l', b'i')},
             (b' ', b'n'): {(b' ', b'n', b'i', b'c', b'e')},
             (b'n', b'i'): {(b' ', b'n', b'i', b'c', b'e')},
             (b'i', b'c'): {(b' ', b'n', b'i', b'c', b'e')},
             (b'c', b'e'): {(b' ', b'n', b'i', b'c', b'e')},
             (b' ', b't'): {(b' ', b't', b'h', b'e'),
              (b' ', b't', b'h'

In [45]:
# First half of one merge iteration of the chunked `train_bpe`.
merges = []


# Most frequent pair; ties go to the greatest pair of byte strings.
max_count = max(pair_counts.values())
candidates = [pair for pair, cnt in pair_counts.items() if cnt == max_count]
best_pair = max(candidates)
# Record merge and add merged token to vocab
merges.append(best_pair)
merged_token = best_pair[0] + best_pair[1]
vocab[len(vocab)] = merged_token

# pop() removes `best_pair` from `pair_index`, so this cell only yields the
# affected words the first time it is run for a given pair.
affected_words = pair_index.pop(best_pair, set())
affected_words

{(b' ', b'h', b'e', b'r', b'e'),
 (b' ', b't', b'h', b'e'),
 (b' ', b'w', b'h', b'e', b'n'),
 (b' ', b'w', b'h', b'e', b'r', b'e')}

In [None]:
# Find all word sequences that contain the best pair.
# The cell that selected `best_pair` has already popped it from `pair_index`,
# so popping again with an empty default would silently turn this cell into a
# no-op; fall back to the set that cell produced instead.
affected_words = pair_index.pop(best_pair, affected_words)

# No `continue` guard here: at cell level there is no enclosing loop (it is a
# SyntaxError), and the loops below simply do nothing for an empty set.
updates = {}
for word_seq in affected_words:
    freq = pretokenized_counter.pop(word_seq, 0)
    if freq == 0:
        # Word already rewritten (e.g. this cell was run before).
        continue
    # remove old pair contribution
    for pair in zip(word_seq, word_seq[1:]):
        pair_counts[pair] -= freq
        if pair_counts[pair] <= 0:
            pair_counts.pop(pair, None)
        pair_index[pair].discard(word_seq)
    # merge occurrences of best_pair in word_seq
    merged_seq = []
    i = 0
    while i < len(word_seq):
        if (
            i + 1 < len(word_seq)
            and word_seq[i] == best_pair[0]
            and word_seq[i + 1] == best_pair[1]
        ):
            merged_seq.append(merged_token)
            i += 2
        else:
            merged_seq.append(word_seq[i])
            i += 1
    merged_seq = tuple(merged_seq)
    updates[merged_seq] = updates.get(merged_seq, 0) + freq

# add the pair contributions of the rewritten words
for w_new, freq in updates.items():
    prev_freq = pretokenized_counter.get(w_new, 0)
    for pair in zip(w_new, w_new[1:]):
        pair_counts[pair] += freq
        pair_index[pair].add(w_new)
    pretokenized_counter[w_new] = freq + prev_freq

In [48]:
# Look up the count of `word_seq` as left in the kernel by an earlier cell.
pretokenized_counter[word_seq]

1

In [None]:
# Second half of one merge iteration, reusing the `affected_words` set produced
# by the cell that selected `best_pair`.
# No `continue` guard here: at cell level there is no enclosing loop (it is a
# SyntaxError), and the loops below simply do nothing for an empty set.
updates = {}
for word_seq in affected_words:
    freq = pretokenized_counter.pop(word_seq, 0)
    if freq == 0:
        # Word already rewritten (e.g. this cell was run before).
        continue
    # remove old pair contribution
    for pair in zip(word_seq, word_seq[1:]):
        pair_counts[pair] -= freq
        if pair_counts[pair] <= 0:
            pair_counts.pop(pair, None)
        pair_index[pair].discard(word_seq)
    # merge occurrences of best_pair in word_seq
    merged_seq = []
    i = 0
    while i < len(word_seq):
        if (
            i + 1 < len(word_seq)
            and word_seq[i] == best_pair[0]
            and word_seq[i + 1] == best_pair[1]
        ):
            merged_seq.append(merged_token)
            i += 2
        else:
            merged_seq.append(word_seq[i])
            i += 1
    merged_seq = tuple(merged_seq)
    updates[merged_seq] = updates.get(merged_seq, 0) + freq

# add the pair contributions of the rewritten words
for w_new, freq in updates.items():
    prev_freq = pretokenized_counter.get(w_new, 0)
    for pair in zip(w_new, w_new[1:]):
        pair_counts[pair] += freq
        pair_index[pair].add(w_new)
    pretokenized_counter[w_new] = freq + prev_freq

