In [None]:
import pandas as pd
from collections import defaultdict
import regex as re
from multiprocessing import Pool
from support.find_chunk_boundaries import find_chunk_boundaries
from memory_profiler import profile
import time, tracemalloc
from dataclasses import dataclass

## Q1.1 Problem (train_bpe): BPE Tokenizer Training 


In [2]:
Pair = tuple[int,int]
Encoded_Token = tuple[int, ...]
PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

In [3]:
@dataclass(frozen=True)
class TokenMergePlan():
    old_token: Encoded_Token
    new_token: Encoded_Token
    count: int
    pair_positions: list[int]

class PairFreqsDelta():
    inc: defaultdict[Pair, int]
    inc: defaultdict[Pair, int]
    def __init__(self):
        self.inc = defaultdict(int)
        self.dec = defaultdict(int)

class PairInhereitDelta():
    add: defaultdict[Pair,Encoded_Token]
    remove: defaultdict[Pair,Encoded_Token]
    def __init__(self):
        self.add = defaultdict(Encoded_Token)
        self.remove = defaultdict(Encoded_Token)

In [None]:
def read_text_file(input_path: str) -> str:
    with open(input_path, "r", encoding="utf-8") as f:
        text = f.read()
    return text

def split_by_special_tokens(string: str, special_tokens: list[str]) -> list[str]:
    pattern = "|".join(re.escape(tok) for tok in special_tokens)
    return re.split(pattern,string)

def count_tokens(string_list: list[str]) -> dict[str, int]:
    counts = defaultdict(int)
    for s in string_list:
        tokens = re.finditer(PAT, s)
        for m in tokens:
            tok = m.group(0)
            counts[tok] += 1
    return counts

def encode_and_count_tokens(counts: dict[str, int])-> dict[Encoded_Token, int]:
    encoded_token_freqs = defaultdict(int)
    for token, count in counts.items():
        elements = tuple(token.encode("utf-8"))
        encoded_token_freqs[elements] += count
    return encoded_token_freqs

def build_initial_vocab(special_tokens: list[str]) ->  dict[int, bytes]:
    vocab = {i:bytes([i]) for i in range(256)}
    for i, tok in enumerate(special_tokens, start=256):
        vocab[i] = tok.encode("utf-8")
    return vocab

def get_byte_pairs(encoded_token_freqs: dict[Encoded_Token, int]) -> dict[Pair, int]:
    pair_freqs = defaultdict(int)
    for tok, count in encoded_token_freqs.items():
        for i in range(len(tok)-1):
            pair_freqs[(tok[i],tok[i+1])] += count
    return pair_freqs

def get_byte_pairs_inhereit(encoded_token_freqs: dict[Encoded_Token, int]) -> dict[Pair, set[Encoded_Token]]:
    pair_inhereit = defaultdict(set)
    for tok, count in encoded_token_freqs.items():
        for i in range(len(tok)-1):
            pair_inhereit[(tok[i],tok[i+1])].add(tok)
    return pair_inhereit

def select_merge_pair(pair_freqs: dict[Pair, int], vocab:dict[int, bytes]) -> Pair:
    max_count = max(pair_freqs.values())
    candidate_pairs = [key for key, value in pair_freqs.items() if value == max_count]
    def sort_pair(pair):
        index1, index2 = pair
        return(vocab[index1], vocab[index2])
    pair = max(candidate_pairs, key = sort_pair)
    return pair

def update_encoded_token(encoded_token: Encoded_Token, pair: Pair, new_index: int) -> Encoded_Token:
    result = []
    i = 0
    while i < len(encoded_token):
        if i < len(encoded_token) - 1 and (encoded_token[i], encoded_token[i + 1]) == pair:
            result.append(new_index)
            i += 2
        else:
            result.append(encoded_token[i])
            i += 1
    return tuple(result)

def find_subtuple_index(sequence: tuple, subseq: tuple) -> list[int]:
    position = []
    subseq_len = len(subseq)
    for i in range(len(sequence)-subseq_len+1):
        if sequence[i:i+subseq_len] == subseq:
            position.append(i)
    return position

def remove_or_decrement_pair(pair_freqs: dict[Pair, int], pair: Pair, count: int) -> dict[Pair, int]:
    updated_pair_freqs = pair_freqs.copy()
    if updated_pair_freqs[pair] == count:
        del updated_pair_freqs[pair]
    else:
        updated_pair_freqs[pair] -= count
    return updated_pair_freqs

def build_merge_plan(tok_need_update: set[Encoded_Token], encoded_token_freqs: dict[Encoded_Token, int], pair: Pair, new_index: int) -> list[TokenMergePlan]:
    plan = []
    for encoded_token in tok_need_update:
        new_encoded_token = update_encoded_token(encoded_token, pair, new_index)
        count = encoded_token_freqs[encoded_token]
        pair_positions = find_subtuple_index(encoded_token,pair)
        plan.append(TokenMergePlan(encoded_token,new_encoded_token,count,pair_positions))
    return plan

def update_encoded_token_freqs(plan: list[TokenMergePlan], encoded_token_freqs: dict[Encoded_Token, int]) ->  dict[Encoded_Token, int]:
    new_encoded_token_freqs = encoded_token_freqs.copy()
    for item in plan:
        del new_encoded_token_freqs[item.old_token]
        new_encoded_token_freqs[item.new_token] = item.count
    return new_encoded_token_freqs

def compute_freqs_and_inhereit_deltas(plan: list[TokenMergePlan], new_index:int) -> tuple[PairFreqsDelta, PairInhereitDelta]:
    pair_freqs_d = PairFreqsDelta()
    pair_inhereit_d = PairInhereitDelta()
    for item in plan:
        old_token = item.old_token
        new_token = item.new_token
        count = item.count

        for pos in item.pair_positions:
            if pos > 0:
                pre_token = old_token[pos-1]
                old_pair = (pre_token,old_token[pos])
                new_pair = (pre_token, new_index)
                pair_freqs_d.dec[old_pair] = count
                pair_freqs_d.inc[new_pair] = count
                pair_inhereit_d.add[new_pair] = new_token
                pair_inhereit_d.remove[old_pair] = old_token
            if pos < len(old_token)-2:
                pos_token = old_token[pos+2]
                old_pair = (old_token[pos+1],pos_token)
                new_pair = (new_index, pos_token)
                pair_freqs_d.dec[old_pair] = count
                pair_freqs_d.inc[new_pair] = count
                pair_inhereit_d.add[new_pair] = new_token
                pair_inhereit_d.remove[old_pair] = old_token
    return pair_freqs_d, pair_inhereit_d

def exclude_pair_from_dict(d: dict, pair: Pair):
    new_d = d.copy()
    del new_d[pair]
    return new_d

def update_pair_freqs(pair_freqs: dict[Pair, int], pair_freqs_d: PairFreqsDelta):
    new_pair_freqs = pair_freqs.copy()
    for key, value in pair_freqs_d.dec.items():
        new_pair_freqs = remove_or_decrement_pair(new_pair_freqs, key, value)
    for key, value in pair_freqs_d.inc.items():
        new_pair_freqs[key]+=value
    return new_pair_freqs

def update_pair_inhereit(pair_inhereit: dict[Pair, set[Encoded_Token]], pair_inhereit_d: PairInhereitDelta):
    new_pair_inhereit = pair_inhereit.copy()
    for key, value in pair_inhereit_d.remove.items():
        new_pair_inhereit[key].discard(value)
    for key, value in pair_inhereit_d.add.items():
        new_pair_inhereit[key].add(value)
    return new_pair_inhereit

def train_bpe(input_path: str, vocab_size: int, special_tokens: list[str]) -> tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:
    vocab = build_initial_vocab(special_tokens)
    vocab_len = len(vocab)
    merges = []

    string = read_text_file(input_path)
    string_list = split_by_special_tokens(string,special_tokens)
    token_count = count_tokens(string_list)
    encoded_token_freqs = encode_and_count_tokens(token_count)
    
    pair_freqs = get_byte_pairs(encoded_token_freqs)
    pair_inhereit = get_byte_pairs_inhereit(encoded_token_freqs)

    while vocab_len < vocab_size:
        pair = select_merge_pair(pair_freqs, vocab)
        index1, index2 = pair
        new_token = vocab[int(index1)]+vocab[int(index2)]
        new_index = vocab_len
        vocab[new_index] = new_token
        merges.append((vocab[int(index1)], vocab[int(index2)]))
        tok_need_update = pair_inhereit[pair]
        plan = build_merge_plan(tok_need_update, encoded_token_freqs, pair, new_index)
        encoded_token_freqs = update_encoded_token_freqs(plan, encoded_token_freqs)
        pair_freqs_d, pair_inhereit_d = compute_freqs_and_inhereit_deltas(plan, new_index)
        pair_freqs = exclude_pair_from_dict(pair_freqs, pair)
        pair_freqs = update_pair_freqs(pair_freqs, pair_freqs_d)
        pair_inhereit = exclude_pair_from_dict(pair_inhereit, pair)
        pair_inhereit = update_pair_inhereit(pair_inhereit, pair_inhereit_d)
        vocab_len+=1
        
    return vocab, merges

In [5]:
special_tokens = ['<|endoftext|>']
input_path = fr'./data/test.txt'

epochs = 6
vocab_size = 270
string = "hi. i'm yifan li. nice to meet you.<|endoftext|> this the what when here where"
# input_path = r'./data/TinyStoriesV2-GPT4-valid.txt'
# string = read_text_file(input_path)

In [516]:
%memit vocab, merges = train_bpe(input_path, vocab_size, special_tokens)

peak memory: 861.81 MiB, increment: 20.62 MiB


## 个人优化

In [18]:
vocab_size = 280
string = "hi. i'm yifan li. nice to meet you. <|endoftext|> this the what when here where. <|endoftext|> where is the car. <|endoftext|> when is dinner"
special_tokens = ['<|endoftext|>']
input_path = fr'./data/test.txt'

In [19]:
vocab, merges = train_bpe(input_path, vocab_size, special_tokens)

In [20]:
merges

[(b'h', b'e'),
 (b' ', b'w'),
 (b' ', b't'),
 (b'r', b'e'),
 (b'i', b's'),
 (b'w', b'he'),
 (b'whe', b'r'),
 (b'wher', b'e'),
 (b'w', b'h'),
 (b't', b'he'),
 (b'he', b'n'),
 (b'w', b'hen'),
 (b'h', b'i'),
 (b'e', b'r'),
 (b' ', b'y'),
 (b' ', b'where'),
 (b' ', b'when'),
 (b' ', b'the'),
 (b' ', b'is'),
 (b'o', b'u'),
 (b'n', b'n'),
 (b'n', b'i'),
 (b'ni', b'c')]

In [21]:
vocab

{0: b'\x00',
 1: b'\x01',
 2: b'\x02',
 3: b'\x03',
 4: b'\x04',
 5: b'\x05',
 6: b'\x06',
 7: b'\x07',
 8: b'\x08',
 9: b'\t',
 10: b'\n',
 11: b'\x0b',
 12: b'\x0c',
 13: b'\r',
 14: b'\x0e',
 15: b'\x0f',
 16: b'\x10',
 17: b'\x11',
 18: b'\x12',
 19: b'\x13',
 20: b'\x14',
 21: b'\x15',
 22: b'\x16',
 23: b'\x17',
 24: b'\x18',
 25: b'\x19',
 26: b'\x1a',
 27: b'\x1b',
 28: b'\x1c',
 29: b'\x1d',
 30: b'\x1e',
 31: b'\x1f',
 32: b' ',
 33: b'!',
 34: b'"',
 35: b'#',
 36: b'$',
 37: b'%',
 38: b'&',
 39: b"'",
 40: b'(',
 41: b')',
 42: b'*',
 43: b'+',
 44: b',',
 45: b'-',
 46: b'.',
 47: b'/',
 48: b'0',
 49: b'1',
 50: b'2',
 51: b'3',
 52: b'4',
 53: b'5',
 54: b'6',
 55: b'7',
 56: b'8',
 57: b'9',
 58: b':',
 59: b';',
 60: b'<',
 61: b'=',
 62: b'>',
 63: b'?',
 64: b'@',
 65: b'A',
 66: b'B',
 67: b'C',
 68: b'D',
 69: b'E',
 70: b'F',
 71: b'G',
 72: b'H',
 73: b'I',
 74: b'J',
 75: b'K',
 76: b'L',
 77: b'M',
 78: b'N',
 79: b'O',
 80: b'P',
 81: b'Q',
 82: b'R',
 83: b'

### Update merge logic

In [189]:
vocab = build_initial_vocab(special_tokens)
vocab_len = len(vocab)
merges = []

encoded_token_freqs = defaultdict(int)
encoded_token_freqs[(1,2,3,4,2,3)] = 3
encoded_token_freqs[(2,3,5)] = 2
pair_freqs,pair_positions = get_byte_pairs_freqs_and_position(encoded_token_freqs)
pair = select_merge_pair(pair_freqs, vocab)

index1, index2 = pair
new_token = vocab[int(index1)]+vocab[int(index2)]
new_index = vocab_len
vocab[new_index] = new_token
vocab_len+=1
merges.append((vocab[int(index1)], vocab[int(index2)]))


# pair_freqs = update_pair_freqs_after_merge(pair_freqs, positions, encoded_token_freqs, pair, new_index)



In [None]:
updated_pair_positions = initiate_update(pair_positions,pair)
updated_pair_freqs = initiate_update(pair_freqs,pair)

encoded_token_map = {}
positions = pair_positions[pair]
for position in positions:
    encoded_token, count = list(encoded_token_freqs.items())[position]
    positions = find_subtuple_index(encoded_token,pair)
    encoded_token_map[encoded_token] = update_encoded_token(encoded_token,positions, new_index)
    for pos in positions:
        if pos > 0:
            pre_token = encoded_token[pos-1]
            old_pair = (pre_token,encoded_token[pos])
            new_pair = (pre_token, new_index)
            updated_pair_freqs = remove_or_decrement_pair(updated_pair_freqs, old_pair, count)
            updated_pair_freqs[new_pair]+=count
            updated_pair_positions[new_pair].add(position)
        if pos < len(encoded_token)-2:
            pos_token = encoded_token[pos+2]
            old_pair = (encoded_token[pos+1],pos_token)
            new_pair = (new_index, pos_token)
            updated_pair_freqs = remove_or_decrement_pair(updated_pair_freqs, old_pair, count)
            updated_pair_freqs[new_pair]+=count
            updated_pair_positions[new_pair].add(position)

updated_encoded_token_freqs = update_encoded_token_freqs(encoded_token_freqs,encoded_token_map)
[updated_pair_freqs,updated_encoded_token_freqs,updated_pair_positions]

(1, 2, 3, 4, 2, 3)
(2, 3, 5)


[defaultdict(int, {(1, 257): 3, (257, 4): 3, (4, 257): 3, (257, 5): 2}),
 {(1, 2, 3, 4, 2, 3): 3, (2, 3, 5): 2},
 defaultdict(set,
             {(1, 2): {0},
              (3, 4): {0},
              (4, 2): {0},
              (3, 5): {1},
              (1, 257): {0},
              (257, 4): {0},
              (4, 257): {0},
              (257, 5): {1}})]

In [None]:
update_encoded_token!!!!!

{(1, 2, 3, 4, 2, 3): (1, 2, 3, 4, 2, 3), (2, 3, 5): (2, 3, 5)}

In [185]:
position = 0
encoded_token, count = list(encoded_token_freqs.items())[position]

In [187]:
update_encoded_token(encoded_token, pair, new_index)

(1, 257, 4, 257)

In [183]:
def update_encoded_token(encoded_token, pair, new_index):
    result = []
    i = 0
    while i < len(encoded_token):
        if i < len(encoded_token) - 1 and (encoded_token[i], encoded_token[i + 1]) == pair:
            result.append(new_index)
            i += 2
        else:
            result.append(encoded_token[i])
            i += 1
    return tuple(result)

{(1, 2, 3, 4, 2, 3): (1, 2, 3, 4, 2, 3), (2, 3, 5): (2, 3, 5)}

In [182]:
updated_encoded_token_freqs = {}
for i, (key,value) in enumerate(encoded_token_freqs.items()):
    if key in encoded_token_map:
        print(key)
        new_key = encoded_token_map[key]
        print(new_key)
        updated_encoded_token_freqs[new_key] = value
    else:
        updated_encoded_token_freqs[key] = value
    # print(updated_encoded_token_freqs)
updated_encoded_token_freqs

(1, 2, 3, 4, 2, 3)
(1, 2, 3, 4, 2, 3)
(2, 3, 5)
(2, 3, 5)


{(1, 2, 3, 4, 2, 3): 3, (2, 3, 5): 2}

In [None]:
def update_encoded_token_freqs(encoded_token_freqs: dict[list[int],int],encoded_token_map: dict[list[int], list[int]]) -> dict[list[int],int]:
    updated_encoded_token_freqs = {}
    for i, (key,value) in enumerate(encoded_token_freqs.items()):
        if key in encoded_token_map:
            new_key = encoded_token_map[key]
            updated_encoded_token_freqs[new_key] = value
        else:
            updated_encoded_token_freqs[key] = value
    return updated_encoded_token_freqs

In [None]:

    
    
updated_pair_positions, updated_pair_freqs, updated_encoded_token_freqs = update_all_structures(pair_positions, pair_freqs, encoded_token_freqs, pair, new_index)




[defaultdict(int, {(1, 257): 3, (257, 4): 3, (4, 257): 3, (257, 5): 2}),
 {(1, 2, 3, 4, 2, 3): 3, (2, 3, 5): 2},
 defaultdict(set,
             {(1, 2): {0},
              (3, 4): {0},
              (4, 2): {0},
              (3, 5): {1},
              (1, 257): {0},
              (257, 4): {0},
              (4, 257): {0},
              (257, 5): {1}})]

**updated_encoded_token_freqs is not correct**

{(1, 257, 4, 257): 3, (257, 5): 2}

In [124]:
position = 0
encoded_token, count = list(encoded_token_freqs.items())[position]
pair_positions = find_subtuple_index(encoded_token,pair)

(1, 257, 4, 257)

In [127]:
def apply_function_to_keys(A: dict, positions: list[int], F):
    keys = list(A.keys())  # 保持 key 的顺序
    for pos in positions:
        old_key = keys[pos]
        value = A.pop(old_key)         # 先取出旧值并删除旧 key
        new_key = F(old_key)           # 使用函数 F 得到新 key
        A[new_key] = value   

In [128]:
A = {('a', 'b'): 1, ('b', 'c'): 2, ('c', 'd'): 3}
positions = [0, 2]

# 假设你想把 key 的每个字符都变成大写
def F(key):
    return tuple(k.upper() for k in key)

apply_function_to_keys(A, positions, F)

print(A)

{('b', 'c'): 2, ('A', 'B'): 1, ('C', 'D'): 3}


In [47]:
encoded_token, count = list(encoded_token_freqs.items())[position]
[encoded_token,count]

[(1, 2, 3, 4, 2, 3), 3]

In [48]:
pair_positions = find_subtuple_index(encoded_token,pair)
pair_positions

[1, 4]

In [73]:
pair_freqs

defaultdict(int, {(1, 2): 3, (2, 3): 8, (3, 4): 3, (4, 2): 3, (3, 5): 2})

In [None]:
updated_pair_freqs  = pair_freqs.copy()
updated_pair_freqs[pair] -= count

In [None]:
updated_pair_freqs = pair_freqs.copy()
for pair_position in pair_positions:
    updated_pair_freqs = update_pair_freqs(pair_freqs, encoded_token, pair_position, count)
updated_pair_freqs

defaultdict(int, {(1, 2): 3, (2, 3): 5, (3, 4): 3, (3, 5): 2, (4, 257): 3})

In [None]:
def update_pair_freqs(pair_freqs: dict[tuple[int, int], int], encoded_token: tuple[int], pair_position: int, count: int) -> dict[tuple[int, int], int]:
    updated_pair_freqs  = pair_freqs.copy()
    updated_pair_freqs[pair] -= count
    if pair_position > 0:
        pre_token = encoded_token[pair_position-1]
        pre_pair = (pre_token,encoded_token[pair_position])
        new_pair = (pre_token, new_index)
        if updated_pair_freqs[pre_pair] == count:
            del updated_pair_freqs[pre_pair]
        else:
            updated_pair_freqs[pre_pair]-=count
        updated_pair_freqs[new_pair]+=count
    if pair_position < len(encoded_token)-2:
        pos_token = encoded_token[pair_position+2]
        pos_pair = (encoded_token[pair_position+1],pos_token)
        new_pair = (new_index, pos_token)
        if updated_pair_freqs[pos_pair] == count:
            del updated_pair_freqs[pos_pair]
        else:
            updated_pair_freqs[pos_pair]-=count
        updated_pair_freqs[new_pair]+=count
    return updated_pair_freqs

defaultdict(int, {(1, 2): 3, (2, 3): 8, (3, 4): 3, (4, 2): 3, (3, 5): 2})

(2, 3)

In [59]:
len(encoded_token)-2

4

3

In [55]:
pair_freqs

defaultdict(int, {(1, 2): 3, (2, 3): 8, (3, 4): 3, (4, 2): 3, (3, 5): 2})

## Unit test

In [457]:
input_path = r"./tests/fixtures/tinystories_sample_5M.txt"

In [458]:
vocab, merges = train_bpe(
        input_path=input_path,
        vocab_size=1000,
        special_tokens=["<|endoftext|>"],
    )

In [459]:
vocabs_without_specials = [word for word in vocab.values() if word != b"<|endoftext|>"]
for word_bytes in vocabs_without_specials:
    assert b"<|" not in word_bytes

In [460]:
vocab_size = 1000
special_tokens=["<|endoftext|>"]
vocab, merges = train_bpe(input_path, vocab_size, special_tokens)

In [461]:
import pickle

with open("./tests/_snapshots/test_train_bpe_special_tokens.pkl", "rb") as f:
    data = pickle.load(f)
    
data.keys()

dict_keys(['vocab_keys', 'vocab_values', 'merges'])

In [462]:
for i, (my_merge, target_merge) in enumerate(zip(merges,data['merges'])):
    if not my_merge == target_merge:
        print([i, my_merge, target_merge])
        break

## Q1.2 Problem (train_bpe_tinystories): BPE Training on TinyStories 

In [None]:
%load_ext memory_profiler

In [14]:
import time

In [21]:
input_path = r"./data/TinyStoriesV2-GPT4-valid.txt"
special_tokens = ['<|endoftext|>']
vocab_size = 32000

In [22]:
start_time = time.time()
%memit vocab, merges = train_bpe(input_path, vocab_size, special_tokens)
end_time = time.time()
print(f"⏱️ 运行时间: {end_time - start_time:.2f} 秒")

peak memory: 356.20 MiB, increment: 23.55 MiB
⏱️ 运行时间: 85.28 秒


In [25]:
input_path = r'./data/TinyStoriesV2-GPT4-train.txt'

extracted

'TinyStoriesV2-GPT4-train'