In [1]:
import pandas as pd
from collections import defaultdict
import regex as re
from multiprocessing import Pool
from support.find_chunk_boundaries import find_chunk_boundaries
from memory_profiler import profile

## Q1.1 Problem (train_bpe): BPE Tokenizer Training 


In [None]:







def update_element_counts(encoded_token_freqs: dict[str, int], pair: tuple[int, int], new_index: int) -> dict[str, int]:
    updated_byte_level_counts = {}
    for elements, counts in encoded_token_freqs.items():
        new_element = []
        elements_len = len(elements)
        index = 0
        while index <= elements_len-1:
            if (index < elements_len-1) and (elements[index] == pair[0]) and (elements[index+1] == pair[1]):
                new_element.append(new_index)
                index += 2
            else:
                new_element.append(elements[index])
                index += 1
        new_byte_level_counts[tuple(new_element)] = counts
    return new_byte_level_counts  

  




def pre_tokenize(string: str,special_tokens: list[str]) -> dict[str, int]:
    string_list = split_by_special_tokens(string, special_tokens)
    token_freqs = count_tokens(string_list)
    return token_freqs

def tokenize_chunk(args):
    chunk_text, special_tokens = args
    return pre_tokenize(chunk_text, special_tokens)

def read_text_chunks(input_path: str, special_tokens: list[str], num_processes: int = 4) -> list[str]:
    with open(input_path, "rb") as f:
        num_processes = 4
        boundaries = find_chunk_boundaries(f, num_processes, b"<|endoftext|>")

        chunks = []
        for start, end in zip(boundaries[:-1], boundaries[1:]):
            f.seek(start)
            chunk = f.read(end - start).decode("utf-8", errors="ignore")
            chunks.append((chunk, special_tokens))
    return chunks

def train_bpe(input_path: str, vocab_size: int, special_tokens: list[str]) -> tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:
    string = read_text_file(input_path)
    token_freqs = pre_tokenize(string,special_tokens)
    # num_processes = 4
    # chunks = read_text_chunks(input_path, special_tokens, num_processes)
    # with Pool(num_processes) as pool:
    #     results = pool.map(tokenize_chunk, chunks)
        
    # token_freqs = defaultdict(int)
    # for partial_count in results:
    #     for word, count in partial_count.items():
    #         token_freqs[word] += count
            
    encoded_token_freqs = encode_and_count_tokens(token_freqs)
    vocab = build_initial_vocab(special_tokens)
    vocab_len = len(vocab)
    merges = []
    while vocab_len<vocab_size:
        pair_freqs = count_byte_pairs(encoded_token_freqs)
        if len(pair_freqs) == 0:
            break
        pair = select_merge_pair(pair_freqs, vocab)
        index1, index2 = pair
        new_token = vocab[int(index1)]+vocab[int(index2)]
        new_index = vocab_len
        encoded_token_freqs = update_element_counts(encoded_token_freqs, pair,new_index)
        merges.append((vocab[int(index1)], vocab[int(index2)]))
        vocab[new_index] = new_token
        vocab_len+=1
    return vocab, merges

In [None]:
special_tokens = ['<|endoftext|>']
input_path = fr'./data/test.txt'

epochs = 6
vocab_size = 270
string = "hi. i'm yifan li. nice to meet you.<|endoftext|> this the what when here where"
input_path = r'./data/TinyStoriesV2-GPT4-valid.txt'
string = read_text_file(input_path)

In [516]:
%memit vocab, merges = train_bpe(input_path, vocab_size, special_tokens)

peak memory: 861.81 MiB, increment: 20.62 MiB


## 个人优化

In [188]:
def read_text_file(input_path: str) -> str:
    with open(input_path, "r", encoding="utf-8") as f:
        text = f.read()
    return text

def split_by_special_tokens(string: str, special_tokens: list[str]) -> list[str]:
    pattern = "|".join(re.escape(tok) for tok in special_tokens)
    return re.split(pattern,string)

PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
def count_tokens(string_list: list[str]) -> dict[str, int]:
    counts = defaultdict(int)
    for s in string_list:
        tokens = re.finditer(PAT, s)
        for m in tokens:
            tok = m.group(0)
            counts[tok] += 1
    return counts

def encode_and_count_tokens(counts: dict[str, int])-> dict[str, int]:
    encoded_token_freqs = defaultdict(int)
    for token, count in counts.items():
        elements = tuple(token.encode("utf-8"))
        encoded_token_freqs[elements] += count
    return encoded_token_freqs

def build_initial_vocab(special_tokens: list[str]) ->  dict[int, bytes]:
    vocab = {i:bytes([i]) for i in range(256)}
    for i, tok in enumerate(special_tokens, start=256):
        vocab[i] = tok.encode("utf-8")
    return vocab 

def get_byte_pairs_freqs_and_position(encoded_token_freqs: dict[str, int]) -> tuple[dict[tuple[int,int], int], dict[tuple[int,int], set[int]]]:
    pair_freqs = defaultdict(int)
    pair_position = defaultdict(set)
    for index, (elements, count) in enumerate(encoded_token_freqs.items()):
        for i in range(len(elements)-1):
            pair_freqs[(elements[i],elements[i+1])] += count
            pair_position[(elements[i],elements[i+1])].add(index)
    return pair_freqs, pair_position

def select_merge_pair(pair_freqs: dict[tuple[int,int], int], vocab:dict[int, bytes]) -> tuple[int, int]:
    max_count = max(pair_freqs.values())
    candidate_pairs = [key for key, value in pair_freqs.items() if value == max_count]
    def sort_pair(pair):
        index1, index2 = pair
        return(vocab[index1], vocab[index2])
    pair = max(candidate_pairs, key = sort_pair)
    return pair

def find_subtuple_index(sequence: tuple, subseq: tuple) -> list[int]:
    position = []
    subseq_len = len(subseq)
    for i in range(len(sequence)-subseq_len+1):
        if sequence[i:i+subseq_len] == subseq:
            position.append(i)
    return position











def update_encoded_token(encoded_token, pair, new_index):
    result = []
    i = 0
    while i < len(encoded_token):
        if i < len(encoded_token) - 1 and (encoded_token[i], encoded_token[i + 1]) == pair:
            result.append(new_index)
            i += 2
        else:
            result.append(encoded_token[i])
            i += 1
    return tuple(result)

def remove_or_decrement_pair(pair_freqs: dict[tuple[int,int], int], pair: tuple[int,int], count: int):
    updated_pair_freqs = pair_freqs.copy()
    if updated_pair_freqs[pair] == count:
        del updated_pair_freqs[pair]
    else:
        updated_pair_freqs[pair] -= count
    return updated_pair_freqs

def update_encoded_token_freqs(encoded_token_freqs: dict[list[int],int],encoded_token_map: dict[list[int], list[int]]) -> dict[list[int],int]:
    updated_encoded_token_freqs = {}
    for i, (key,value) in enumerate(encoded_token_freqs.items()):
        if key in encoded_token_map:
            new_key = encoded_token_map[key]
            updated_encoded_token_freqs[new_key] = value
        else:
            updated_encoded_token_freqs[key] = value
    return updated_encoded_token_freqs

def initiate_update(x, pair):
    y = x.copy()
    del y[pair]
    return y



def update_all_structures(pair_positions, pair_freqs, encoded_token_freqs, pair, new_index):
    updated_pair_positions = initiate_update(pair_positions,pair)
    updated_pair_freqs = initiate_update(pair_freqs,pair)

    encoded_token_map = {}
    positions = pair_positions[pair]
    for position in positions:
        encoded_token, count = list(encoded_token_freqs.items())[position]
        positions = find_subtuple_index(encoded_token,pair)
        encoded_token_map[encoded_token] = update_encoded_token(encoded_token,positions, new_index)
        for pos in positions:
            if pos > 0:
                pre_token = encoded_token[pos-1]
                old_pair = (pre_token,encoded_token[pos])
                new_pair = (pre_token, new_index)
                updated_pair_freqs = remove_or_decrement_pair(updated_pair_freqs, old_pair, count)
                updated_pair_freqs[new_pair]+=count
                updated_pair_positions[new_pair].add(position)
            if pos < len(encoded_token)-2:
                pos_token = encoded_token[pos+2]
                old_pair = (encoded_token[pos+1],pos_token)
                new_pair = (new_index, pos_token)
                updated_pair_freqs = remove_or_decrement_pair(updated_pair_freqs, old_pair, count)
                updated_pair_freqs[new_pair]+=count
                updated_pair_positions[new_pair].add(position)

    updated_encoded_token_freqs = update_encoded_token_freqs(encoded_token_freqs,encoded_token_map)
    return updated_pair_positions, updated_pair_freqs, updated_encoded_token_freqs



In [166]:
special_tokens = ['<|endoftext|>']
input_path = r'./data/test.txt'
vocab_size = 260
string = "hi. i'm yifan li. nice to meet you.<|endoftext|> this the what when here where"

In [167]:
string_list = split_by_special_tokens(string, special_tokens)
token_freqs = count_tokens(string_list)
encoded_token_freqs = encode_and_count_tokens(token_freqs)
vocab = build_initial_vocab(special_tokens)
vocab_len = len(vocab)
merges = []

pair_freqs = count_byte_pairs(encoded_token_freqs)

### Update merge logic

In [189]:
vocab = build_initial_vocab(special_tokens)
vocab_len = len(vocab)
merges = []

encoded_token_freqs = defaultdict(int)
encoded_token_freqs[(1,2,3,4,2,3)] = 3
encoded_token_freqs[(2,3,5)] = 2
pair_freqs,pair_positions = get_byte_pairs_freqs_and_position(encoded_token_freqs)
pair = select_merge_pair(pair_freqs, vocab)

index1, index2 = pair
new_token = vocab[int(index1)]+vocab[int(index2)]
new_index = vocab_len
vocab[new_index] = new_token
vocab_len+=1
merges.append((vocab[int(index1)], vocab[int(index2)]))


# pair_freqs = update_pair_freqs_after_merge(pair_freqs, positions, encoded_token_freqs, pair, new_index)



In [None]:
updated_pair_positions = initiate_update(pair_positions,pair)
updated_pair_freqs = initiate_update(pair_freqs,pair)

encoded_token_map = {}
positions = pair_positions[pair]
for position in positions:
    encoded_token, count = list(encoded_token_freqs.items())[position]
    positions = find_subtuple_index(encoded_token,pair)
    encoded_token_map[encoded_token] = update_encoded_token(encoded_token,positions, new_index)
    for pos in positions:
        if pos > 0:
            pre_token = encoded_token[pos-1]
            old_pair = (pre_token,encoded_token[pos])
            new_pair = (pre_token, new_index)
            updated_pair_freqs = remove_or_decrement_pair(updated_pair_freqs, old_pair, count)
            updated_pair_freqs[new_pair]+=count
            updated_pair_positions[new_pair].add(position)
        if pos < len(encoded_token)-2:
            pos_token = encoded_token[pos+2]
            old_pair = (encoded_token[pos+1],pos_token)
            new_pair = (new_index, pos_token)
            updated_pair_freqs = remove_or_decrement_pair(updated_pair_freqs, old_pair, count)
            updated_pair_freqs[new_pair]+=count
            updated_pair_positions[new_pair].add(position)

updated_encoded_token_freqs = update_encoded_token_freqs(encoded_token_freqs,encoded_token_map)
[updated_pair_freqs,updated_encoded_token_freqs,updated_pair_positions]

(1, 2, 3, 4, 2, 3)
(2, 3, 5)


[defaultdict(int, {(1, 257): 3, (257, 4): 3, (4, 257): 3, (257, 5): 2}),
 {(1, 2, 3, 4, 2, 3): 3, (2, 3, 5): 2},
 defaultdict(set,
             {(1, 2): {0},
              (3, 4): {0},
              (4, 2): {0},
              (3, 5): {1},
              (1, 257): {0},
              (257, 4): {0},
              (4, 257): {0},
              (257, 5): {1}})]

In [None]:
update_encoded_token!!!!!

{(1, 2, 3, 4, 2, 3): (1, 2, 3, 4, 2, 3), (2, 3, 5): (2, 3, 5)}

In [185]:
position = 0
encoded_token, count = list(encoded_token_freqs.items())[position]

In [187]:
update_encoded_token(encoded_token, pair, new_index)

(1, 257, 4, 257)

In [183]:
def update_encoded_token(encoded_token, pair, new_index):
    result = []
    i = 0
    while i < len(encoded_token):
        if i < len(encoded_token) - 1 and (encoded_token[i], encoded_token[i + 1]) == pair:
            result.append(new_index)
            i += 2
        else:
            result.append(encoded_token[i])
            i += 1
    return tuple(result)

{(1, 2, 3, 4, 2, 3): (1, 2, 3, 4, 2, 3), (2, 3, 5): (2, 3, 5)}

In [182]:
updated_encoded_token_freqs = {}
for i, (key,value) in enumerate(encoded_token_freqs.items()):
    if key in encoded_token_map:
        print(key)
        new_key = encoded_token_map[key]
        print(new_key)
        updated_encoded_token_freqs[new_key] = value
    else:
        updated_encoded_token_freqs[key] = value
    # print(updated_encoded_token_freqs)
updated_encoded_token_freqs

(1, 2, 3, 4, 2, 3)
(1, 2, 3, 4, 2, 3)
(2, 3, 5)
(2, 3, 5)


{(1, 2, 3, 4, 2, 3): 3, (2, 3, 5): 2}

In [None]:
def update_encoded_token_freqs(encoded_token_freqs: dict[list[int],int],encoded_token_map: dict[list[int], list[int]]) -> dict[list[int],int]:
    updated_encoded_token_freqs = {}
    for i, (key,value) in enumerate(encoded_token_freqs.items()):
        if key in encoded_token_map:
            new_key = encoded_token_map[key]
            updated_encoded_token_freqs[new_key] = value
        else:
            updated_encoded_token_freqs[key] = value
    return updated_encoded_token_freqs

In [None]:

    
    
updated_pair_positions, updated_pair_freqs, updated_encoded_token_freqs = update_all_structures(pair_positions, pair_freqs, encoded_token_freqs, pair, new_index)




[defaultdict(int, {(1, 257): 3, (257, 4): 3, (4, 257): 3, (257, 5): 2}),
 {(1, 2, 3, 4, 2, 3): 3, (2, 3, 5): 2},
 defaultdict(set,
             {(1, 2): {0},
              (3, 4): {0},
              (4, 2): {0},
              (3, 5): {1},
              (1, 257): {0},
              (257, 4): {0},
              (4, 257): {0},
              (257, 5): {1}})]

**updated_encoded_token_freqs is not correct**

{(1, 257, 4, 257): 3, (257, 5): 2}

In [124]:
position = 0
encoded_token, count = list(encoded_token_freqs.items())[position]
pair_positions = find_subtuple_index(encoded_token,pair)

(1, 257, 4, 257)

In [127]:
def apply_function_to_keys(A: dict, positions: list[int], F):
    keys = list(A.keys())  # 保持 key 的顺序
    for pos in positions:
        old_key = keys[pos]
        value = A.pop(old_key)         # 先取出旧值并删除旧 key
        new_key = F(old_key)           # 使用函数 F 得到新 key
        A[new_key] = value   

In [128]:
A = {('a', 'b'): 1, ('b', 'c'): 2, ('c', 'd'): 3}
positions = [0, 2]

# 假设你想把 key 的每个字符都变成大写
def F(key):
    return tuple(k.upper() for k in key)

apply_function_to_keys(A, positions, F)

print(A)

{('b', 'c'): 2, ('A', 'B'): 1, ('C', 'D'): 3}


In [47]:
encoded_token, count = list(encoded_token_freqs.items())[position]
[encoded_token,count]

[(1, 2, 3, 4, 2, 3), 3]

In [48]:
pair_positions = find_subtuple_index(encoded_token,pair)
pair_positions

[1, 4]

In [73]:
pair_freqs

defaultdict(int, {(1, 2): 3, (2, 3): 8, (3, 4): 3, (4, 2): 3, (3, 5): 2})

In [None]:
updated_pair_freqs  = pair_freqs.copy()
updated_pair_freqs[pair] -= count

In [None]:
updated_pair_freqs = pair_freqs.copy()
for pair_position in pair_positions:
    updated_pair_freqs = update_pair_freqs(pair_freqs, encoded_token, pair_position, count)
updated_pair_freqs

defaultdict(int, {(1, 2): 3, (2, 3): 5, (3, 4): 3, (3, 5): 2, (4, 257): 3})

In [None]:
def update_pair_freqs(pair_freqs: dict[tuple[int, int], int], encoded_token: tuple[int], pair_position: int, count: int) -> dict[tuple[int, int], int]:
    updated_pair_freqs  = pair_freqs.copy()
    updated_pair_freqs[pair] -= count
    if pair_position > 0:
        pre_token = encoded_token[pair_position-1]
        pre_pair = (pre_token,encoded_token[pair_position])
        new_pair = (pre_token, new_index)
        if updated_pair_freqs[pre_pair] == count:
            del updated_pair_freqs[pre_pair]
        else:
            updated_pair_freqs[pre_pair]-=count
        updated_pair_freqs[new_pair]+=count
    if pair_position < len(encoded_token)-2:
        pos_token = encoded_token[pair_position+2]
        pos_pair = (encoded_token[pair_position+1],pos_token)
        new_pair = (new_index, pos_token)
        if updated_pair_freqs[pos_pair] == count:
            del updated_pair_freqs[pos_pair]
        else:
            updated_pair_freqs[pos_pair]-=count
        updated_pair_freqs[new_pair]+=count
    return updated_pair_freqs

defaultdict(int, {(1, 2): 3, (2, 3): 8, (3, 4): 3, (4, 2): 3, (3, 5): 2})

(2, 3)

In [59]:
len(encoded_token)-2

4

3

In [55]:
pair_freqs

defaultdict(int, {(1, 2): 3, (2, 3): 8, (3, 4): 3, (4, 2): 3, (3, 5): 2})

## Unit test

In [457]:
input_path = r"./tests/fixtures/tinystories_sample_5M.txt"

In [458]:
vocab, merges = train_bpe(
        input_path=input_path,
        vocab_size=1000,
        special_tokens=["<|endoftext|>"],
    )

In [459]:
vocabs_without_specials = [word for word in vocab.values() if word != b"<|endoftext|>"]
for word_bytes in vocabs_without_specials:
    assert b"<|" not in word_bytes

In [460]:
vocab_size = 1000
special_tokens=["<|endoftext|>"]
vocab, merges = train_bpe(input_path, vocab_size, special_tokens)

In [461]:
import pickle

with open("./tests/_snapshots/test_train_bpe_special_tokens.pkl", "rb") as f:
    data = pickle.load(f)
    
data.keys()

dict_keys(['vocab_keys', 'vocab_values', 'merges'])

In [462]:
for i, (my_merge, target_merge) in enumerate(zip(merges,data['merges'])):
    if not my_merge == target_merge:
        print([i, my_merge, target_merge])
        break

## Q1.2 Problem (train_bpe_tinystories): BPE Training on TinyStories 

In [None]:
%load_ext memory_profiler

In [14]:
import time

In [21]:
input_path = r"./data/TinyStoriesV2-GPT4-valid.txt"
special_tokens = ['<|endoftext|>']
vocab_size = 32000

In [22]:
start_time = time.time()
%memit vocab, merges = train_bpe(input_path, vocab_size, special_tokens)
end_time = time.time()
print(f"⏱️ 运行时间: {end_time - start_time:.2f} 秒")

peak memory: 356.20 MiB, increment: 23.55 MiB
⏱️ 运行时间: 85.28 秒


In [25]:
input_path = r'./data/TinyStoriesV2-GPT4-train.txt'

extracted

'TinyStoriesV2-GPT4-train'