In [1]:
import pandas as pd
from collections import defaultdict
import regex as re
from multiprocessing import Pool
from support.find_chunk_boundaries import find_chunk_boundaries
from memory_profiler import profile
import time, tracemalloc
from dataclasses import dataclass

## Q1.1 Problem (train_bpe): BPE Tokenizer Training 


### Version 1.0

In [44]:
from collections import defaultdict
import regex as re


def load_txt_as_str(input_path: str) -> str:
    with open(input_path, "r", encoding="utf-8") as f:
        text = f.read()
    return text

def split_string(string: str, special_tokens: list[str]) -> list[str]:
    pattern = "|".join(re.escape(tok) for tok in special_tokens)
    return re.split(pattern,string)

PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
def get_tok_counts(string_list: list[str]) -> dict[str, int]:
    counts = defaultdict(int)
    for s in string_list:
        tokens = re.finditer(PAT, s)
        for m in tokens:
            tok = m.group(0)
            counts[tok] += 1
    return counts

def get_byte_counts(counts: dict[str, int])-> dict[str, int]:
    element_counts = defaultdict(int)
    for token, count in counts.items():
        elements = tuple(token.encode("utf-8"))
        element_counts[elements] += count
    return element_counts

def get_pair_counts(element_counts: dict[str, int]) -> dict[tuple[int,int], int]:
    pair_counts = defaultdict(int)
    for elements, count in element_counts.items():
        for i in range(len(elements)-1):
            pair_counts[(elements[i],elements[i+1])] += count
    return pair_counts


def update_element_counts(byte_level_counts: dict[str, int], pair: tuple[int, int], new_index: int) -> dict[str, int]:
    new_byte_level_counts = {}
    for elements, counts in byte_level_counts.items():
        new_element = []
        elements_len = len(elements)
        index = 0
        while index <= elements_len-1:
            if (index < elements_len-1) and (elements[index] == pair[0]) and (elements[index+1] == pair[1]):
                new_element.append(new_index)
                index += 2
            else:
                new_element.append(elements[index])
                index += 1
        new_byte_level_counts[tuple(new_element)] = counts
    return new_byte_level_counts  

def initiate_vocab(special_tokens: list[str]) ->  dict[int, bytes]:
    vocab = {i:bytes([i]) for i in range(256)}
    for i, tok in enumerate(special_tokens, start=256):
        vocab[i] = tok.encode("utf-8")
    return vocab   

def find_max_pair(pair_counts: dict[tuple[int,int], int], vocab:dict[int, bytes]) -> tuple[int, int]:
    max_count = max(pair_counts.values())
    candidate_pairs = [key for key, value in pair_counts.items() if value == max_count]
    def sort_pair(pair):
        index1, index2 = pair
        return(vocab[index1], vocab[index2])
    pair = max(candidate_pairs, key = sort_pair)
    return pair


def pre_tokenize(string: str,vocab_size: int,special_tokens: list[str]) -> tuple[dict[int, bytes],list[tuple[bytes, bytes]]]:
    merges = []
    string_list = split_string(string, special_tokens)
    word_level_counts = get_tok_counts(string_list)
    byte_level_counts = get_byte_counts(word_level_counts)
    vocab = initiate_vocab(special_tokens)
    vocab_len = len(vocab)

    while vocab_len<vocab_size:
        pair_counts = get_pair_counts(byte_level_counts)
        if len(pair_counts) == 0:
            break
        pair = find_max_pair(pair_counts, vocab)
        index1, index2 = pair
        new_token = vocab[int(index1)]+vocab[int(index2)]
        new_index = vocab_len
        byte_level_counts = update_element_counts(byte_level_counts, pair,new_index)
        merges.append((vocab[int(index1)], vocab[int(index2)]))
        vocab[new_index] = new_token
        vocab_len+=1
    return vocab, merges

def train_bpe(input_path: str, vocab_size: int, special_tokens: list[str]) -> tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:
    string = load_txt_as_str(input_path)
    vocab, merges = pre_tokenize(string, vocab_size,special_tokens)
    return vocab, merges

### Version 2.0

In [2]:
Pair = tuple[int,int]
Encoded_Token = tuple[int, ...]
PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

In [3]:

@dataclass(frozen=True)
class TokenMergePlan():
    old_token: Encoded_Token
    new_token: Encoded_Token
    count: int
    pair_positions: list[int]

class PairFreqsDelta():
    inc: defaultdict[Pair, int]
    inc: defaultdict[Pair, int]
    def __init__(self):
        self.inc = defaultdict(int)
        self.dec = defaultdict(int)

class PairInhereitDelta():
    add: defaultdict[Pair,set[Encoded_Token]]
    remove: defaultdict[Pair,set[Encoded_Token]]
    def __init__(self):
        self.add = defaultdict(set[Encoded_Token])
        self.remove = defaultdict(set[Encoded_Token])

In [4]:
def read_text_file(input_path: str) -> str:
    with open(input_path, "r", encoding="utf-8") as f:
        text = f.read()
    return text

def split_by_special_tokens(string: str, special_tokens: list[str]) -> list[str]:
    pattern = "|".join(re.escape(tok) for tok in special_tokens)
    return re.split(pattern,string)

def count_tokens(string_list: list[str]) -> dict[str, int]:
    counts = defaultdict(int)
    for s in string_list:
        tokens = re.finditer(PAT, s)
        for m in tokens:
            tok = m.group(0)
            counts[tok] += 1
    return counts

def encode_and_count_tokens(counts: dict[str, int])-> dict[Encoded_Token, int]:
    encoded_token_freqs = defaultdict(int)
    for token, count in counts.items():
        elements = tuple(token.encode("utf-8"))
        encoded_token_freqs[elements] += count
    return encoded_token_freqs

def build_initial_vocab(special_tokens: list[str]) ->  dict[int, bytes]:
    vocab = {i:bytes([i]) for i in range(256)}
    for i, tok in enumerate(special_tokens, start=256):
        vocab[i] = tok.encode("utf-8")
    return vocab

def get_byte_pairs(encoded_token_freqs: dict[Encoded_Token, int]) -> dict[Pair, int]:
    pair_freqs = defaultdict(int)
    for tok, count in encoded_token_freqs.items():
        for i in range(len(tok)-1):
            pair_freqs[(tok[i],tok[i+1])] += count
    return pair_freqs

def get_byte_pairs_inhereit(encoded_token_freqs: dict[Encoded_Token, int]) -> dict[Pair, set[Encoded_Token]]:
    pair_inhereit = defaultdict(set)
    for tok, count in encoded_token_freqs.items():
        for i in range(len(tok)-1):
            pair_inhereit[(tok[i],tok[i+1])].add(tok)
    return pair_inhereit

def select_merge_pair(pair_freqs: dict[Pair, int], vocab:dict[int, bytes]) -> Pair:
    max_count = max(pair_freqs.values())
    candidate_pairs = [key for key, value in pair_freqs.items() if value == max_count]
    def sort_pair(pair):
        index1, index2 = pair
        return(vocab[index1], vocab[index2])
    pair = max(candidate_pairs, key = sort_pair)
    return pair

def update_encoded_token(encoded_token: Encoded_Token, pair: Pair, new_index: int) -> Encoded_Token:
    result = []
    i = 0
    while i < len(encoded_token):
        if i < len(encoded_token) - 1 and (encoded_token[i], encoded_token[i + 1]) == pair:
            result.append(new_index)
            i += 2
        else:
            result.append(encoded_token[i])
            i += 1
    return tuple(result)

def find_subtuple_index(sequence: tuple, subseq: tuple) -> list[int]:
    position = []
    subseq_len = len(subseq)
    for i in range(len(sequence)-subseq_len+1):
        if sequence[i:i+subseq_len] == subseq:
            position.append(i)
    return position

def remove_or_decrement_pair(pair_freqs: dict[Pair, int], pair: Pair, count: int) -> dict[Pair, int]:
    updated_pair_freqs = pair_freqs.copy()
    if updated_pair_freqs[pair] == count:
        del updated_pair_freqs[pair]
    else:
        updated_pair_freqs[pair] -= count
    return updated_pair_freqs

def build_merge_plan(tok_need_update: set[Encoded_Token], encoded_token_freqs: dict[Encoded_Token, int], pair: Pair, new_index: int) -> list[TokenMergePlan]:
    plan = []
    for encoded_token in tok_need_update:
        new_encoded_token = update_encoded_token(encoded_token, pair, new_index)
        count = encoded_token_freqs[encoded_token]
        pair_positions = find_subtuple_index(encoded_token,pair)
        plan.append(TokenMergePlan(encoded_token,new_encoded_token,count,pair_positions))
    return plan

def update_encoded_token_freqs(plan: list[TokenMergePlan], encoded_token_freqs: dict[Encoded_Token, int]) ->  dict[Encoded_Token, int]:
    new_encoded_token_freqs = encoded_token_freqs.copy()
    for item in plan:
        del new_encoded_token_freqs[item.old_token]
        new_encoded_token_freqs[item.new_token] = item.count
    return new_encoded_token_freqs

def compute_freqs_deltas(plan: list[TokenMergePlan], new_index:int) -> PairFreqsDelta:
    pair_freqs_d = PairFreqsDelta()
    for item in plan:
        old_token = item.old_token
        count = item.count
        for pos in item.pair_positions:
            if pos > 0:
                pre_token = old_token[pos-1]
                old_pair = (pre_token,old_token[pos])
                new_pair = (pre_token, new_index)
                pair_freqs_d.dec[old_pair] += count
                pair_freqs_d.inc[new_pair] += count
            if pos < len(old_token)-2:
                pos_token = old_token[pos+2]
                old_pair = (old_token[pos+1],pos_token)
                new_pair = (new_index, pos_token)
                pair_freqs_d.dec[old_pair] += count
                pair_freqs_d.inc[new_pair] += count
    return pair_freqs_d

def compute_inhereit_deltas(plan: list[TokenMergePlan]) -> PairInhereitDelta:
    pair_inhereit_d = PairInhereitDelta()
    for item in plan:
        if len(item.old_token) > 1:
            for old_pair in zip(item.old_token,item.old_token[1:]):
                pair_inhereit_d.remove[old_pair].add(item.old_token)
        if len(item.new_token) > 1:
            for new_pair in zip(item.new_token,item.new_token[1:]):
                pair_inhereit_d.add[new_pair].add(item.new_token)
    return pair_inhereit_d 

def exclude_pair_from_dict(d: dict, pair: Pair):
    new_d = d.copy()
    del new_d[pair]
    return new_d

def update_pair_freqs(pair_freqs: dict[Pair, int], pair_freqs_d: PairFreqsDelta):
    new_pair_freqs = pair_freqs.copy()
    for key, value in pair_freqs_d.dec.items():
        new_pair_freqs = remove_or_decrement_pair(new_pair_freqs, key, value)
    for key, value in pair_freqs_d.inc.items():
        new_pair_freqs[key]+=value
    return new_pair_freqs

def update_pair_inhereit(pair_inhereit: dict[Pair, set[Encoded_Token]], pair_inhereit_d: PairInhereitDelta):
    new_pair_inhereit = pair_inhereit.copy()
    for key, value in pair_inhereit_d.remove.items():
        new_pair_inhereit[key] -= value
    for key, value in pair_inhereit_d.add.items():
        new_pair_inhereit[key] = new_pair_inhereit[key] | value
    return new_pair_inhereit

def train_bpe(input_path: str, vocab_size: int, special_tokens: list[str]) -> tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:
    vocab = build_initial_vocab(special_tokens)
    vocab_len = len(vocab)
    merges = []

    string = read_text_file(input_path)
    string_list = split_by_special_tokens(string,special_tokens)
    token_count = count_tokens(string_list)
    encoded_token_freqs = encode_and_count_tokens(token_count)
    
    pair_freqs = get_byte_pairs(encoded_token_freqs)
    pair_inhereit = get_byte_pairs_inhereit(encoded_token_freqs)

    while vocab_len < vocab_size:
        pair = select_merge_pair(pair_freqs, vocab)
        index1, index2 = pair
        new_token = vocab[int(index1)]+vocab[int(index2)]
        new_index = vocab_len
        vocab[new_index] = new_token
        merges.append((vocab[int(index1)], vocab[int(index2)]))
        tok_need_update = pair_inhereit[pair]
        plan = build_merge_plan(tok_need_update, encoded_token_freqs, pair, new_index)
        encoded_token_freqs = update_encoded_token_freqs(plan, encoded_token_freqs)
        pair_freqs_d = compute_freqs_deltas(plan, new_index)
        pair_freqs = exclude_pair_from_dict(pair_freqs, pair)
        pair_freqs = update_pair_freqs(pair_freqs, pair_freqs_d)
        pair_inhereit_d = compute_inhereit_deltas(plan)
        pair_inhereit = exclude_pair_from_dict(pair_inhereit, pair)
        pair_inhereit = update_pair_inhereit(pair_inhereit, pair_inhereit_d)
        vocab_len+=1
    return vocab, merges

## 持续优化

In [6]:
special_tokens = ['<|endoftext|>']
input_path = fr'./data/test.txt'
vocab_size = 280

In [7]:
vocab = build_initial_vocab(special_tokens)
vocab_len = len(vocab)
merges = []

string = read_text_file(input_path)
string_list = split_by_special_tokens(string,special_tokens)
token_count = count_tokens(string_list)
encoded_token_freqs = encode_and_count_tokens(token_count)

pair_freqs = get_byte_pairs(encoded_token_freqs)
pair_inhereit = get_byte_pairs_inhereit(encoded_token_freqs)

In [10]:
## while
pair = select_merge_pair(pair_freqs, vocab)
index1, index2 = pair
new_token = vocab[int(index1)]+vocab[int(index2)]
new_index = vocab_len
vocab[new_index] = new_token
merges.append((vocab[int(index1)], vocab[int(index2)]))
tok_need_update = pair_inhereit[pair]
plan = build_merge_plan(tok_need_update, encoded_token_freqs, pair, new_index)

In [5]:
import pickle

with open("./data/TinyStoriesV2-GPT4-train_10000_vocab.pkl", "rb") as f:
    vocab = pickle.load(f)


In [14]:
import heapq

def top_n_longest_byte_values(d: dict, n: int):
    return heapq.nlargest(n, d.items(), key=lambda x: len(x[1]))

In [3]:
top_n_longest_byte_values(vocab,20)

[(7172, b' accomplishment'),
 (9156, b' disappointment'),
 (9393, b' responsibility'),
 (3236, b' uncomfortable'),
 (3524, b' compassionate'),
 (5327, b' understanding'),
 (6401, b' neighbourhood'),
 (6512, b' Unfortunately'),
 (6888, b' determination'),
 (7769, b' encouragement'),
 (8638, b' unfortunately'),
 (8711, b' congratulated'),
 (8881, b' extraordinary'),
 (9108, b' granddaughter'),
 (256, b'<|endoftext|>'),
 (3346, b' disappointed'),
 (3777, b' enthusiastic'),
 (4368, b' accidentally'),
 (4383, b' refrigerator'),
 (4479, b' veterinarian')]

In [14]:
top_n_longest_byte_values(vocab,20)

[(7172, b' accomplishment'),
 (9156, b' disappointment'),
 (9393, b' responsibility'),
 (3236, b' uncomfortable'),
 (3524, b' compassionate'),
 (5327, b' understanding'),
 (6401, b' neighbourhood'),
 (6512, b' Unfortunately'),
 (6888, b' determination'),
 (7769, b' encouragement'),
 (8638, b' unfortunately'),
 (8711, b' congratulated'),
 (8881, b' extraordinary'),
 (9108, b' granddaughter'),
 (256, b'<|endoftext|>'),
 (3346, b' disappointed'),
 (3777, b' enthusiastic'),
 (4368, b' accidentally'),
 (4383, b' refrigerator'),
 (4479, b' veterinarian')]

# 个人优化

## 对比最终结果

In [5]:
special_tokens = ['<|endoftext|>']
input_path = fr'./data/test.txt'
vocab_size = 280

In [7]:
vocab, merges = train_bpe(input_path, vocab_size, special_tokens)
merges

[(b'h', b'e'),
 (b' ', b'w'),
 (b' w', b'he'),
 (b' ', b't'),
 (b'r', b'e'),
 (b'i', b's'),
 (b' whe', b're'),
 (b' whe', b'n'),
 (b' t', b'he'),
 (b' ', b'y'),
 (b' ', b'is'),
 (b'o', b'u'),
 (b'n', b'n'),
 (b'nn', b'e'),
 (b'nne', b'r'),
 (b'n', b'i'),
 (b'ni', b'c'),
 (b'nic', b'e'),
 (b'm', b'e'),
 (b'me', b'e'),
 (b'mee', b't'),
 (b'l', b'i'),
 (b'i', b'nner')]

## 对比每一步

In [23]:
special_tokens = ['<|endoftext|>']
input_path = fr'./data/test.txt'
vocab_size = 280
# string = """hi. i'm yifan li. nice to meet you. <|endoftext|>
# this the what when here where. <|endoftext|>
# where is the car. <|endoftext|>
# when is dinner"""

In [27]:
vocab = build_initial_vocab(special_tokens)
vocab_len = len(vocab)
merges = []

string = read_text_file(input_path)
string = string.replace("\r\n", "\n").replace("\r", "")
string_list = split_by_special_tokens(string,special_tokens)
token_count = count_tokens(string_list)
encoded_token_freqs = encode_and_count_tokens(token_count)

pair_freqs = get_byte_pairs(encoded_token_freqs)
pair_inhereit = get_byte_pairs_inhereit(encoded_token_freqs)

In [28]:
sorted_items = sorted(token_count.items(), key=lambda x: x[1], reverse=True)
sorted_items

[('.', 5),
 (' ', 3),
 ('"', 2),
 (' the', 2),
 (' when', 2),
 (' where', 2),
 (' is', 2),
 ('hi', 1),
 (' i', 1),
 ("'m", 1),
 (' yifan', 1),
 (' li', 1),
 (' nice', 1),
 (' to', 1),
 (' meet', 1),
 (' you', 1),
 (' this', 1),
 (' what', 1),
 (' here', 1),
 (' car', 1),
 (' dinner', 1)]

In [29]:
sorted_items = sorted(pair_freqs.items(), key=lambda x: (x[1], x[0]), reverse=True)
sorted_items

[((104, 101), 7),
 ((119, 104), 5),
 ((32, 119), 5),
 ((101, 114), 4),
 ((32, 116), 4),
 ((116, 104), 3),
 ((114, 101), 3),
 ((105, 115), 3),
 ((32, 105), 3),
 ((104, 105), 2),
 ((101, 110), 2),
 ((32, 121), 2),
 ((121, 111), 1),
 ((121, 105), 1),
 ((116, 111), 1),
 ((111, 117), 1),
 ((110, 110), 1),
 ((110, 105), 1),
 ((110, 101), 1),
 ((109, 101), 1),
 ((108, 105), 1),
 ((105, 110), 1),
 ((105, 102), 1),
 ((105, 99), 1),
 ((104, 97), 1),
 ((102, 97), 1),
 ((101, 116), 1),
 ((101, 101), 1),
 ((100, 105), 1),
 ((99, 101), 1),
 ((99, 97), 1),
 ((97, 116), 1),
 ((97, 114), 1),
 ((97, 110), 1),
 ((39, 109), 1),
 ((32, 110), 1),
 ((32, 109), 1),
 ((32, 108), 1),
 ((32, 104), 1),
 ((32, 100), 1),
 ((32, 99), 1)]

In [30]:
pair_inhereit

defaultdict(set,
            {(104, 105): {(32, 116, 104, 105, 115), (104, 105)},
             (32, 105): {(32, 105), (32, 105, 115)},
             (39, 109): {(39, 109)},
             (32, 121): {(32, 121, 105, 102, 97, 110), (32, 121, 111, 117)},
             (121, 105): {(32, 121, 105, 102, 97, 110)},
             (105, 102): {(32, 121, 105, 102, 97, 110)},
             (102, 97): {(32, 121, 105, 102, 97, 110)},
             (97, 110): {(32, 121, 105, 102, 97, 110)},
             (32, 108): {(32, 108, 105)},
             (108, 105): {(32, 108, 105)},
             (32, 110): {(32, 110, 105, 99, 101)},
             (110, 105): {(32, 110, 105, 99, 101)},
             (105, 99): {(32, 110, 105, 99, 101)},
             (99, 101): {(32, 110, 105, 99, 101)},
             (32, 116): {(32, 116, 104, 101),
              (32, 116, 104, 105, 115),
              (32, 116, 111)},
             (116, 111): {(32, 116, 111)},
             (32, 109): {(32, 109, 101, 101, 116)},
             (109, 101)

In [37]:
pair = select_merge_pair(pair_freqs, vocab)
index1, index2 = pair
new_token = vocab[int(index1)]+vocab[int(index2)]
new_index = vocab_len
vocab[new_index] = new_token
merges.append((vocab[int(index1)], vocab[int(index2)]))
tok_need_update = pair_inhereit[pair]
plan = build_merge_plan(tok_need_update, encoded_token_freqs, pair, new_index)

In [40]:
pair_freqs_d = compute_freqs_deltas(plan, new_index)
pair_freqs_d

<__main__.PairFreqsDelta at 0x1278624a0>

In [41]:
pair_freqs_d.dec

defaultdict(int,
            {(119, 104): 4,
             (101, 110): 2,
             (32, 104): 1,
             (101, 114): 3,
             (116, 104): 2})

In [92]:
pair = select_merge_pair(pair_freqs, vocab)
index1, index2 = pair
new_token = vocab[int(index1)]+vocab[int(index2)]
new_index = vocab_len
vocab[new_index] = new_token
merges.append((vocab[int(index1)], vocab[int(index2)]))
tok_need_update = pair_inhereit[pair]
plan = build_merge_plan(tok_need_update, encoded_token_freqs, pair, new_index)
encoded_token_freqs = update_encoded_token_freqs(plan, encoded_token_freqs)
pair_freqs_d = compute_freqs_deltas(plan, new_index)
pair_freqs = exclude_pair_from_dict(pair_freqs, pair)
pair_freqs = update_pair_freqs(pair_freqs, pair_freqs_d)
pair_inhereit_d = compute_inhereit_deltas(plan)
pair_inhereit = exclude_pair_from_dict(pair_inhereit, pair)
pair_inhereit = update_pair_inhereit(pair_inhereit, pair_inhereit_d)
vocab_len+=1

In [51]:
encoded_token_freqs

defaultdict(int,
            {(104, 105): 1,
             (46,): 5,
             (32, 105): 1,
             (39, 109): 1,
             (32, 121, 105, 102, 97, 110): 1,
             (32, 108, 105): 1,
             (32, 110, 105, 99, 101): 1,
             (32, 116, 111): 1,
             (32, 109, 101, 101, 116): 1,
             (32, 121, 111, 117): 1,
             (32,): 3,
             (32, 116, 104, 105, 115): 1,
             (32, 119, 104, 97, 116): 1,
             (32, 105, 115): 2,
             (32, 99, 97, 114): 1,
             (32, 100, 105, 110, 110, 101, 114): 1,
             (32, 119, 257, 110): 2,
             (32, 257, 114, 101): 1,
             (32, 116, 257): 2,
             (32, 119, 257, 114, 101): 2})

In [None]:
pair_freqs

## debug

In [82]:
special_tokens = ['<|endoftext|>']
input_path = fr'./data/test.txt'
vocab_size = 280
string = """hi. i'm yifan li. nice to meet you. <|endoftext|>
this the what when here where. <|endoftext|>
where is the car. <|endoftext|>
when is dinner"""

In [83]:
vocab = build_initial_vocab(special_tokens)
vocab_len = len(vocab)
merges = []

# string = read_text_file(input_path)
string_list = split_by_special_tokens(string,special_tokens)
token_count = count_tokens(string_list)
encoded_token_freqs = encode_and_count_tokens(token_count)

pair_freqs = get_byte_pairs(encoded_token_freqs)
pair_inhereit = get_byte_pairs_inhereit(encoded_token_freqs)

In [84]:
pair = select_merge_pair(pair_freqs, vocab)
index1, index2 = pair
new_token = vocab[int(index1)]+vocab[int(index2)]
new_index = vocab_len
vocab[new_index] = new_token
merges.append((vocab[int(index1)], vocab[int(index2)]))
tok_need_update = pair_inhereit[pair]
plan = build_merge_plan(tok_need_update, encoded_token_freqs, pair, new_index)
encoded_token_freqs = update_encoded_token_freqs(plan, encoded_token_freqs)
pair_freqs_d = compute_freqs_deltas(plan, new_index)
# pair_freqs = exclude_pair_from_dict(pair_freqs, pair)
# pair_freqs = update_pair_freqs(pair_freqs, pair_freqs_d)
# pair_inhereit_d = compute_inhereit_deltas(plan)
# pair_inhereit = exclude_pair_from_dict(pair_inhereit, pair)
# pair_inhereit = update_pair_inhereit(pair_inhereit, pair_inhereit_d)
# vocab_len+=1

In [85]:
plan

[TokenMergePlan(old_token=(119, 104, 101, 114, 101), new_token=(119, 257, 114, 101), count=1, pair_positions=[1]),
 TokenMergePlan(old_token=(32, 119, 104, 101, 114, 101), new_token=(32, 119, 257, 114, 101), count=1, pair_positions=[2]),
 TokenMergePlan(old_token=(119, 104, 101, 110), new_token=(119, 257, 110), count=1, pair_positions=[1]),
 TokenMergePlan(old_token=(32, 119, 104, 101, 110), new_token=(32, 119, 257, 110), count=1, pair_positions=[2]),
 TokenMergePlan(old_token=(32, 104, 101, 114, 101), new_token=(32, 257, 114, 101), count=1, pair_positions=[1]),
 TokenMergePlan(old_token=(32, 116, 104, 101), new_token=(32, 116, 257), count=2, pair_positions=[2])]

In [86]:
pair_freqs_d.inc

defaultdict(int,
            {(119, 257): 4,
             (257, 114): 3,
             (257, 110): 2,
             (32, 257): 1,
             (116, 257): 2})

In [None]:
def build_merge_plan(tok_need_update: set[Encoded_Token], encoded_token_freqs: dict[Encoded_Token, int], pair: Pair, new_index: int) -> list[TokenMergePlan]:
    plan = []
    for encoded_token in tok_need_update:
        new_encoded_token = update_encoded_token(encoded_token, pair, new_index)
        count = encoded_token_freqs[encoded_token]
        pair_positions = find_subtuple_index(encoded_token,pair)
        plan.append(TokenMergePlan(encoded_token,new_encoded_token,count,pair_positions))
    return plan

In [21]:
def find_diff(A,B):
    # 1. A 有但 B 没有的 key 和值
    only_in_A = {k: A[k] for k in A.keys() - B.keys()}

    # 2. B 有但 A 没有的 key 和值
    only_in_B = {k: B[k] for k in B.keys() - A.keys()}

    # 3. A 和 B 都有，但值不同的 key 和值
    diff_values = {k: (A[k], B[k]) for k in A.keys() & B.keys() if A[k] != B[k]}
    return only_in_A, only_in_B, diff_values

In [30]:
plan

[TokenMergePlan(old_token=(32, 119, 104, 101, 110), new_token=(32, 119, 257, 110), count=2, pair_positions=[2]),
 TokenMergePlan(old_token=(32, 104, 101, 114, 101), new_token=(32, 257, 114, 101), count=1, pair_positions=[1]),
 TokenMergePlan(old_token=(32, 116, 104, 101), new_token=(32, 116, 257), count=2, pair_positions=[2]),
 TokenMergePlan(old_token=(32, 119, 104, 101, 114, 101), new_token=(32, 119, 257, 114, 101), count=2, pair_positions=[2])]

In [22]:
only_in_A, only_in_B, diff_values = find_diff(pair_inhereit,pair_inhereit2)

In [29]:
only_in_A

{}

In [26]:
only_in_B

{(257, 114): {(32, 119, 257, 114, 101), (32, 257, 114, 101)},
 (116, 257): {(32, 116, 257)},
 (257, 110): {(32, 119, 257, 110)},
 (119, 257): {(32, 119, 257, 110), (32, 119, 257, 114, 101)},
 (32, 257): {(32, 257, 114, 101)}}

In [27]:
diff_values

{(32, 116): ({(32, 116, 104, 105, 115), (32, 116, 111)},
  {(32, 116, 104, 105, 115), (32, 116, 111), (32, 116, 257)}),
 (32, 119): ({(32, 119, 104, 97, 116)},
  {(32, 119, 104, 97, 116), (32, 119, 257, 110), (32, 119, 257, 114, 101)}),
 (104,
  101): ({(32, 104, 101, 114, 101),
   (32, 116, 104, 101),
   (32, 119, 104, 101, 110),
   (32, 119, 104, 101, 114, 101)}, set()),
 (114, 101): (set(), {(32, 119, 257, 114, 101), (32, 257, 114, 101)})}

In [36]:
pair_inhereit[(32,116)]

{(32, 116, 104, 101), (32, 116, 104, 105, 115), (32, 116, 111)}

In [43]:
pair_inhereit2[(32,119)]

{(32, 119, 104, 97, 116), (32, 119, 257, 110), (32, 119, 257, 114, 101)}

## Unit test

In [45]:
input_path = r"./tests/fixtures/tinystories_sample_5M.txt"

In [48]:
vocab_size = 1000
special_tokens=["<|endoftext|>"]
vocab, merges = train_bpe(input_path, vocab_size, special_tokens)

In [49]:
import pickle

with open("./tests/_snapshots/test_train_bpe_special_tokens.pkl", "rb") as f:
    data = pickle.load(f)
    
data.keys()

dict_keys(['vocab_keys', 'vocab_values', 'merges'])

In [50]:
for i, (my_merge, target_merge) in enumerate(zip(merges,data['merges'])):
    if not my_merge == target_merge:
        print([i, my_merge, target_merge])
        break

In [12]:
merges

[(b'h', b'e'),
 (b' ', b't'),
 (b' ', b'a'),
 (b't', b'h'),
 (b' ', b's'),
 (b'a', b'n'),
 (b' ', b'w'),
 (b'n', b'd'),
 (b' ', b'h'),
 (b'e', b'd'),
 (b't', b'o'),
 (b'e', b'r'),
 (b' ', b'b'),
 (b'i', b'n'),
 (b'w', b'a'),
 (b' ', b'f'),
 (b'r', b'e'),
 (b' ', b'T'),
 (b'i', b't'),
 (b'o', b'u'),
 (b'a', b's'),
 (b'h', b'a'),
 (b' ', b'l'),
 (b' ', b'd'),
 (b' ', b'i'),
 (b'e', b'n'),
 (b' ', b'c'),
 (b' ', b'p'),
 (b'a', b'y'),
 (b'a', b'r'),
 (b'm', b'e'),
 (b' ', b'm'),
 (b'o', b'm'),
 (b'v', b'e'),
 (b'a', b't'),
 (b' ', b'o'),
 (b'l', b'e'),
 (b's', b'a'),
 (b'h', b'i'),
 (b'n', b'e'),
 (b'n', b'g'),
 (b' ', b'n'),
 (b'i', b'm'),
 (b'l', b'l'),
 (b't', b'e'),
 (b'i', b'd'),
 (b'a', b'l'),
 (b'k', b'e'),
 (b'T', b'h'),
 (b'o', b'r'),
 (b'i', b's'),
 (b' ', b'g'),
 (b' ', b'S'),
 (b's', b't'),
 (b'e', b'a'),
 (b'c', b'a'),
 (b'o', b'o'),
 (b'l', b'a'),
 (b'l', b'i'),
 (b'l', b'o'),
 (b'i', b'l'),
 (b'o', b't'),
 (b'n', b't'),
 (b'a', b'i'),
 (b'a', b'd'),
 (b'o', b'n'),
 (b'u', b'

In [13]:
data['merges']

[(b'h', b'e'),
 (b' ', b't'),
 (b' ', b'a'),
 (b' ', b's'),
 (b' ', b'w'),
 (b'n', b'd'),
 (b' t', b'he'),
 (b'e', b'd'),
 (b' ', b'b'),
 (b' t', b'o'),
 (b' a', b'nd'),
 (b' ', b'h'),
 (b' ', b'f'),
 (b'i', b'n'),
 (b' w', b'a'),
 (b' ', b'T'),
 (b'i', b't'),
 (b'r', b'e'),
 (b'o', b'u'),
 (b' ', b'l'),
 (b' ', b'd'),
 (b' ', b'c'),
 (b' ', b'p'),
 (b'a', b'y'),
 (b' wa', b's'),
 (b'e', b'r'),
 (b' ', b'm'),
 (b'o', b'm'),
 (b' ', b'he'),
 (b' T', b'he'),
 (b'i', b's'),
 (b' ', b'n'),
 (b'o', b'n'),
 (b'a', b'r'),
 (b'i', b'm'),
 (b' s', b'a'),
 (b'l', b'l'),
 (b'i', b'd'),
 (b' h', b'a'),
 (b' ', b'g'),
 (b' ', b'S'),
 (b'a', b't'),
 (b'in', b'g'),
 (b'o', b't'),
 (b'e', b'n'),
 (b'a', b'n'),
 (b'l', b'e'),
 (b'o', b'r'),
 (b'i', b'r'),
 (b' ', b'H'),
 (b'a', b'm'),
 (b'e', b't'),
 (b' ', b'it'),
 (b' t', b'h'),
 (b'i', b'g'),
 (b' The', b'y'),
 (b'i', b'l'),
 (b' ', b'in'),
 (b' H', b'e'),
 (b' p', b'l'),
 (b' ', b'"'),
 (b'o', b'w'),
 (b'v', b'er'),
 (b'r', b'i'),
 (b' ', b'u'),
 (

In [30]:
input_path = r"./tests/fixtures/tinystories_sample_5M.txt"
vocab_size=1000
special_tokens=["<|endoftext|>"]

In [31]:
vocab = build_initial_vocab(special_tokens)
vocab_len = len(vocab)
merges = []

string = read_text_file(input_path)
string_list = split_by_special_tokens(string,special_tokens)
token_count = count_tokens(string_list)
encoded_token_freqs = encode_and_count_tokens(token_count)

pair_freqs = get_byte_pairs(encoded_token_freqs)
pair_inhereit = get_byte_pairs_inhereit(encoded_token_freqs)

In [None]:
pair = select_merge_pair(pair_freqs, vocab)
index1, index2 = pair
new_token = vocab[int(index1)]+vocab[int(index2)]
new_index = vocab_len
vocab[new_index] = new_token
merges.append((vocab[int(index1)], vocab[int(index2)]))
tok_need_update = pair_inhereit[pair]
plan = build_merge_plan(tok_need_update, encoded_token_freqs, pair, new_index)
encoded_token_freqs = update_encoded_token_freqs(plan, encoded_token_freqs)
pair_freqs_d, pair_inhereit_d = compute_freqs_and_inhereit_deltas(plan, new_index)
pair_freqs = exclude_pair_from_dict(pair_freqs, pair)
pair_freqs = update_pair_freqs(pair_freqs, pair_freqs_d)
pair_inhereit = exclude_pair_from_dict(pair_inhereit, pair)
pair_inhereit = update_pair_inhereit(pair_inhereit, pair_inhereit_d)
vocab_len+=1

In [33]:
tok_need_update

{(32, 65, 110, 111, 116, 104, 101, 114),
 (32, 66, 114, 111, 116, 104, 101, 114),
 (32, 67, 104, 101, 102),
 (32, 67, 104, 101, 114, 114, 121),
 (32, 69, 118, 101, 114, 121, 119, 104, 101, 114, 101),
 (32, 70, 117, 114, 116, 104, 101, 114, 109, 111, 114, 101),
 (32, 76, 111, 110, 103, 104, 101, 97, 100),
 (32, 78, 101, 105, 116, 104, 101, 114),
 (32, 79, 116, 104, 101, 114),
 (32, 82, 97, 99, 104, 101, 108),
 (32, 83, 104, 101),
 (32, 83, 104, 101, 108, 108),
 (32, 84, 104, 101),
 (32, 84, 104, 101, 105, 114),
 (32, 84, 104, 101, 110),
 (32, 84, 104, 101, 114, 101),
 (32, 84, 104, 101, 115, 101),
 (32, 84, 104, 101, 121),
 (32, 84, 111, 103, 101, 116, 104, 101, 114),
 (32, 87, 104, 101, 101),
 (32, 87, 104, 101, 110),
 (32, 87, 104, 101, 110, 101, 118, 101, 114),
 (32, 87, 104, 101, 114, 101),
 (32, 97, 97, 104, 101, 100),
 (32, 97, 99, 99, 111, 109, 112, 108, 105, 115, 104, 101, 100),
 (32, 97, 99, 104, 101),
 (32, 97, 104, 101, 97, 100),
 (32, 97, 110, 111, 116, 104, 101, 114),
 (32,

In [25]:
pair = select_merge_pair(pair_freqs, vocab)
index1, index2 = pair
new_token = vocab[int(index1)]+vocab[int(index2)]

In [28]:
pair_freqs[pair]

101570

In [29]:
for key, value in pair_freqs.items():
    if value>= 101570:
        print(key)

(116, 104)


## Q1.2 Problem (train_bpe_tinystories): BPE Training on TinyStories 

In [None]:
%load_ext memory_profiler

In [14]:
import time

In [21]:
input_path = r"./data/TinyStoriesV2-GPT4-valid.txt"
special_tokens = ['<|endoftext|>']
vocab_size = 32000

In [22]:
start_time = time.time()
%memit vocab, merges = train_bpe(input_path, vocab_size, special_tokens)
end_time = time.time()
print(f"⏱️ 运行时间: {end_time - start_time:.2f} 秒")

peak memory: 356.20 MiB, increment: 23.55 MiB
⏱️ 运行时间: 85.28 秒


In [25]:
input_path = r'./data/TinyStoriesV2-GPT4-train.txt'

extracted

'TinyStoriesV2-GPT4-train'