In [440]:
import pandas as pd
import numpy as np
from collections import defaultdict
import regex as re
from multiprocessing import Pool
from support.find_chunk_boundaries import find_chunk_boundaries
from memory_profiler import profile
import time, tracemalloc
from dataclasses import dataclass
from typing import BinaryIO, Iterable, Iterator
import random

## Q1.1 Problem (train_bpe): BPE Tokenizer Training 


In [None]:
# Two adjacent token ids inside a pre-token (a merge candidate).
Pair = tuple[int,int]
# One pre-token written as a sequence of token ids.
Encoded_Token = tuple[int, ...]
# Pre-tokenization pattern. It uses the \p{L} / \p{N} Unicode property classes,
# so it must be compiled by the `regex` package (imported in this notebook as `re`).
PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

In [None]:
@dataclass(frozen=True)
class TokenMergePlan:
    """Immutable record describing how one pre-token changes during a merge step."""

    # Id sequence of the pre-token before the merge.
    old_token: Encoded_Token
    # Id sequence of the same pre-token after the merge.
    new_token: Encoded_Token
    # Frequency of the pre-token at the time the plan was built.
    count: int
    # Start indices, within old_token, where the merged pair was found.
    pair_positions: list[int]

class PairFreqsDelta():
    """Pending adjustments to the pair-frequency table for one merge step.

    Attributes:
        inc: pair -> amount to add to its frequency.
        dec: pair -> amount to subtract from its frequency.
    """
    inc: defaultdict[Pair, int]
    # The original annotated `inc` twice; the second attribute is `dec`.
    dec: defaultdict[Pair, int]
    def __init__(self):
        self.inc = defaultdict(int)
        self.dec = defaultdict(int)

class PairInhereitDelta():
    """Pending adjustments to the pair -> containing-tokens index for one merge step.

    Attributes:
        add: pair -> tokens that must be registered under that pair.
        remove: pair -> tokens that must be unregistered from that pair.
    """
    add: defaultdict[Pair,set[Encoded_Token]]
    remove: defaultdict[Pair,set[Encoded_Token]]
    def __init__(self):
        # Plain `set` is the factory; the parameterised alias only belongs in annotations.
        self.add = defaultdict(set)
        self.remove = defaultdict(set)

In [None]:
def pre_tokenize(string: str,special_tokens: list[str]) -> dict[str, int]:
    """Split `string` on the special tokens, then count the pre-tokens of every piece.

    Returns:
        Mapping from pre-token text to its number of occurrences.
    """
    return count_tokens(split_by_special_tokens(string, special_tokens))

def tokenize_chunk(args):
    """Unpack a `(chunk_text, special_tokens)` pair and pre-tokenize the text.

    Takes a single tuple argument; presumably meant as a worker for `Pool.map`
    (verify against the caller).
    """
    text, specials = args
    return pre_tokenize(text, specials)

def read_text_chunks(input_path: str, special_tokens: list[str], num_processes: int = 4) -> list[tuple[str, list[str]]]:
    """Read `input_path` as decoded text chunks, each paired with the special tokens.

    Args:
        input_path: file to read (opened in binary mode).
        special_tokens: attached unchanged to every chunk so each work item is
            a ready-made argument for `tokenize_chunk`.
        num_processes: chunk count requested from `find_chunk_boundaries`.

    Returns:
        A list of `(chunk_text, special_tokens)` tuples. The original annotation
        said `list[str]`, which did not match the returned value.
    """
    chunks = []
    with open(input_path, "rb") as f:
        # NOTE(review): the split marker is hard-coded and does not follow `special_tokens`.
        boundaries = find_chunk_boundaries(f, num_processes, b"<|endoftext|>")
        for start, end in zip(boundaries[:-1], boundaries[1:]):
            f.seek(start)
            # errors="ignore" drops undecodable bytes silently.
            chunk = f.read(end - start).decode("utf-8", errors="ignore")
            chunks.append((chunk, special_tokens))
    return chunks


def combine_counts(results: list[dict[str,int]]) -> dict[str,int]:
    """Sum several partial count tables into a single table.

    Returns:
        A `defaultdict(int)` mapping each word to its total count.
    """
    merged = defaultdict(int)
    for partial in results:
        for word in partial:
            merged[word] += partial[word]
    return merged

def read_text_file(input_path: str) -> str:
    """Return the full contents of `input_path`, read as UTF-8 text."""
    with open(input_path, "r", encoding="utf-8") as handle:
        return handle.read()

def split_by_special_tokens(string: str, special_tokens: list[str]) -> list[str]:
    """Split `string` at every special token and drop the tokens themselves.

    Args:
        string: text to split.
        special_tokens: literal delimiters; empty strings are ignored.

    Returns:
        The pieces between delimiters. With no usable delimiter the whole
        string comes back as a single piece.
    """
    # Longest first: in an alternation the first matching branch wins, so a
    # token that is a prefix of another must come after the longer one.
    ordered = sorted((tok for tok in special_tokens if tok), key=len, reverse=True)
    if not ordered:
        # An empty pattern would match between every character.
        return [string]
    pattern = "|".join(re.escape(tok) for tok in ordered)
    return re.split(pattern,string)



def count_tokens(string_list: list[str]) -> dict[str, int]:
    """Count how often each `PAT` match occurs across all strings in `string_list`.

    Returns:
        A `defaultdict(int)` mapping matched text to its number of occurrences.
    """
    freqs = defaultdict(int)
    for piece in string_list:
        for match in re.finditer(PAT, piece):
            freqs[match.group(0)] += 1
    return freqs

def encode_and_count_tokens(counts: dict[str, int])-> dict[Encoded_Token, int]:
    """Re-key a pre-token count table by each pre-token's UTF-8 byte values.

    Args:
        counts: pre-token text -> frequency.

    Returns:
        A `defaultdict(int)` mapping a tuple of byte values (ints in 0..255) to
        the summed frequency.
    """
    encoded_token_freqs = defaultdict(int)
    for token, count in counts.items():
        # Iterating bytes yields ints, which matches Encoded_Token and the ids of
        # the initial vocabulary. The original wrapped each byte back into a
        # `bytes` object, which cannot be used as a vocabulary index.
        elements = tuple(token.encode("utf-8"))
        encoded_token_freqs[elements] += count
    return encoded_token_freqs

def build_initial_vocab(special_tokens: list[str]) ->  dict[int, bytes]:
    """Create the starting vocabulary: ids 0..255 for single bytes, then the special tokens.

    Special tokens receive consecutive ids from 256 in the order given and are
    stored as their UTF-8 encoding.
    """
    vocab = {}
    for byte_value in range(256):
        vocab[byte_value] = bytes([byte_value])
    next_id = 256
    for tok in special_tokens:
        vocab[next_id] = tok.encode("utf-8")
        next_id += 1
    return vocab

def get_byte_pairs(encoded_token_freqs: dict[Encoded_Token, int]) -> dict[Pair, int]:
    """Count every adjacent id pair, weighted by the frequency of its pre-token.

    Returns:
        A `defaultdict(int)` mapping each adjacent pair to its total weighted count.
    """
    pair_freqs = defaultdict(int)
    for tok, count in encoded_token_freqs.items():
        for left, right in zip(tok, tok[1:]):
            pair_freqs[(left, right)] += count
    return pair_freqs

def get_byte_pairs_inhereit(encoded_token_freqs: dict[Encoded_Token, int]) -> dict[Pair, set[Encoded_Token]]:
    """Build the reverse index: adjacent pair -> set of pre-tokens containing it.

    Frequencies are not consulted; only the keys of `encoded_token_freqs` matter.
    """
    pair_inhereit = defaultdict(set)
    for tok in encoded_token_freqs.keys():
        for left, right in zip(tok, tok[1:]):
            pair_inhereit[(left, right)].add(tok)
    return pair_inhereit

def select_merge_pair(pair_freqs: dict[Pair, int], vocab:dict[int, bytes]) -> Pair:
    """Pick the most frequent pair, breaking ties by the larger pair of byte strings.

    Among pairs sharing the top count, the winner is the one whose
    `(vocab[first], vocab[second])` tuple compares greatest.

    Raises:
        ValueError: if `pair_freqs` is empty.
    """
    best_count = max(pair_freqs.values())
    tied = [candidate for candidate, freq in pair_freqs.items() if freq == best_count]

    def as_bytes(candidate):
        first, second = candidate
        return (vocab[first], vocab[second])

    return max(tied, key=as_bytes)

def update_encoded_token(encoded_token: Encoded_Token, pair: Pair, new_index: int) -> Encoded_Token:
    """Replace each occurrence of `pair` in `encoded_token` with `new_index`.

    Occurrences are consumed greedily from left to right, so overlapping
    matches are merged only once (e.g. `(a, a, a)` becomes `(new, a)`).
    """
    merged = []
    last = len(encoded_token) - 1
    pos = 0
    while pos <= last:
        if pos < last and (encoded_token[pos], encoded_token[pos + 1]) == pair:
            merged.append(new_index)
            pos += 2
            continue
        merged.append(encoded_token[pos])
        pos += 1
    return tuple(merged)

def find_subtuple_index(sequence: tuple, subseq: tuple) -> list[int]:
    """Return every start index at which `subseq` occurs in `sequence`.

    Overlapping occurrences are all reported.
    """
    width = len(subseq)
    return [
        start
        for start in range(len(sequence) - width + 1)
        if sequence[start:start + width] == subseq
    ]

def remove_or_decrement_pair(pair_freqs: dict[Pair, int], pair: Pair, count: int):
    """Subtract `count` from `pair`'s frequency and delete the entry once it is <= 0.

    Mutates `pair_freqs` in place and returns None.
    """
    remaining = pair_freqs[pair] - count
    if remaining > 0:
        pair_freqs[pair] = remaining
    else:
        del pair_freqs[pair]
        
def remove_or_decrement_pair2(pair_freqs: dict[Pair, set], pair: Pair, to_remove: set):
    """Remove `to_remove` from the set stored under `pair`; delete the entry if it empties.

    Mutates `pair_freqs` (and the stored set) in place and returns None.
    """
    # In-place set difference.
    pair_freqs[pair] -= to_remove
    # Nothing left under this pair: drop the key.
    if len(pair_freqs[pair]) == 0:
        del pair_freqs[pair]

def build_merge_plan(tok_need_update: set[Encoded_Token], encoded_token_freqs: dict[Encoded_Token, int], pair: Pair, new_index: int) -> list[TokenMergePlan]:
    """Describe, for every affected pre-token, how merging `pair` rewrites it.

    Args:
        tok_need_update: pre-tokens that contain `pair`.
        encoded_token_freqs: pre-token -> frequency (read only here).
        pair: the pair being merged.
        new_index: id assigned to the merged pair.

    Returns:
        One `TokenMergePlan` per pre-token. `pair_positions` lists only the
        occurrences that are actually merged.
    """
    plan = []
    for encoded_token in tok_need_update:
        new_encoded_token = update_encoded_token(encoded_token, pair, new_index)
        count = encoded_token_freqs[encoded_token]
        # find_subtuple_index reports overlapping matches too (e.g. both 0 and 1
        # for (a, a, a)), but the merge consumes matches greedily left to right.
        # Keep only the positions that survive the same greedy scan.
        pair_positions = []
        next_free = 0
        for pos in find_subtuple_index(encoded_token,pair):
            if pos >= next_free:
                pair_positions.append(pos)
                next_free = pos + len(pair)
        plan.append(TokenMergePlan(encoded_token,new_encoded_token,count,pair_positions))
    return plan

def update_encoded_token_freqs(plan: list[TokenMergePlan], encoded_token_freqs: dict[Encoded_Token, int]):
    """Re-key the frequency table: each plan's old token is replaced by its new token.

    Mutates `encoded_token_freqs` in place and returns None.

    Raises:
        KeyError: if a plan's `old_token` is not present in the table.
    """
    for step in plan:
        encoded_token_freqs[step.new_token] = encoded_token_freqs.pop(step.old_token)

def compute_freqs_deltas(plan: list[TokenMergePlan], new_index:int) -> PairFreqsDelta:
    """Compute how the pair-frequency table changes when the plan is applied.

    For each plan item the adjacent pairs of `old_token` and `new_token` are
    compared; pairs that became rarer go to `dec`, pairs that became more common
    go to `inc`, both weighted by the item's `count`.

    The original looked only at the neighbours of each merge position. That is
    wrong when merged occurrences touch: for `(a, b, a, b)` it decremented
    `(b, a)` twice, incremented pairs that do not exist in the result, and
    missed `(new, new)`.

    Args:
        plan: merge plans for the affected pre-tokens.
        new_index: id of the merged pair. Kept for interface compatibility; the
            new id is already contained in each `new_token`.

    Returns:
        A `PairFreqsDelta`. The merged pair itself is left out of `dec`.
    """
    pair_freqs_d = PairFreqsDelta()
    for item in plan:
        old_token = item.old_token
        new_token = item.new_token
        count = item.count

        # Identify the merged pair so it can be skipped below; the caller removes
        # its table entry separately.
        merged_pair = None
        if item.pair_positions:
            first = item.pair_positions[0]
            merged_pair = (old_token[first], old_token[first + 1])

        # Net change in occurrences of each pair within this pre-token.
        net = defaultdict(int)
        for old_pair in zip(old_token, old_token[1:]):
            net[old_pair] -= 1
        for new_pair in zip(new_token, new_token[1:]):
            net[new_pair] += 1

        for changed_pair, change in net.items():
            if changed_pair == merged_pair:
                continue
            if change > 0:
                pair_freqs_d.inc[changed_pair] += change * count
            elif change < 0:
                pair_freqs_d.dec[changed_pair] += -change * count
    return pair_freqs_d

def compute_inhereit_deltas(plan: list[TokenMergePlan]) -> PairInhereitDelta:
    """Compute the changes to the pair -> containing-tokens index for a merge plan.

    Every adjacent pair of an old token is marked for removal of that old token;
    every adjacent pair of a new token is marked for addition of that new token.
    Tokens shorter than two ids contribute nothing.
    """
    delta = PairInhereitDelta()
    for step in plan:
        before, after = step.old_token, step.new_token
        for stale_pair in zip(before, before[1:]):
            delta.remove[stale_pair].add(before)
        for fresh_pair in zip(after, after[1:]):
            delta.add[fresh_pair].add(after)
    return delta

def exclude_pair_from_dict(d: dict, pair: Pair):
    """Delete `pair` from `d` in place.

    Raises:
        KeyError: if `pair` is not a key of `d`.
    """
    d.pop(pair)

def update_pair_freqs(pair_freqs: dict[Pair, int], pair_freqs_d: PairFreqsDelta):
    """Apply a frequency delta in place: all decrements first, then all increments.

    Incrementing a pair that is not yet a key relies on `pair_freqs` supplying a
    default, as a `defaultdict(int)` does.
    """
    for shrinking_pair, amount in pair_freqs_d.dec.items():
        remove_or_decrement_pair(pair_freqs, shrinking_pair, amount)
    for growing_pair, amount in pair_freqs_d.inc.items():
        pair_freqs[growing_pair] = pair_freqs[growing_pair] + amount

def update_pair_inhereit(pair_inhereit: dict[Pair, set[Encoded_Token]], pair_inhereit_d: PairInhereitDelta):
    """Apply an index delta in place: all removals first, then all additions.

    Additions bind a fresh set (the union) under the key instead of mutating the
    existing one.
    """
    for stale_pair, stale_tokens in pair_inhereit_d.remove.items():
        remove_or_decrement_pair2(pair_inhereit, stale_pair, stale_tokens)
    for fresh_pair, fresh_tokens in pair_inhereit_d.add.items():
        current = pair_inhereit[fresh_pair]
        pair_inhereit[fresh_pair] = current | fresh_tokens

def train_bpe(input_path: str, vocab_size: int, special_tokens: list[str]) -> tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:
    """Train a byte-level BPE vocabulary on the text file at `input_path`.

    Args:
        input_path: UTF-8 text file used as the training corpus.
        vocab_size: target vocabulary size, counting the 256 byte entries and
            the special tokens.
        special_tokens: strings that are removed from the corpus before counting
            and given their own vocabulary entries.

    Returns:
        `(vocab, merges)`: `vocab` maps token id to bytes, `merges` lists the
        merged byte-string pairs in the order they were created. Training stops
        early if no adjacent pair is left to merge.
    """
    vocab = build_initial_vocab(special_tokens)
    vocab_len = len(vocab)
    merges = []

    string = read_text_file(input_path)
    string_list = split_by_special_tokens(string,special_tokens)
    token_count = count_tokens(string_list)
    encoded_token_freqs = encode_and_count_tokens(token_count)

    pair_freqs = get_byte_pairs(encoded_token_freqs)
    pair_inhereit = get_byte_pairs_inhereit(encoded_token_freqs)

    while vocab_len < vocab_size:
        if not pair_freqs:
            # Every pre-token is a single id already; nothing left to merge.
            break
        pair = select_merge_pair(pair_freqs, vocab)
        index1, index2 = pair
        left, right = vocab[int(index1)], vocab[int(index2)]
        new_index = vocab_len
        vocab[new_index] = left + right
        merges.append((left, right))

        # Rewrite the affected pre-tokens, then bring both pair tables in line.
        tok_need_update = pair_inhereit[pair]
        plan = build_merge_plan(tok_need_update, encoded_token_freqs, pair, new_index)
        update_encoded_token_freqs(plan, encoded_token_freqs)
        pair_freqs_d = compute_freqs_deltas(plan, new_index)
        exclude_pair_from_dict(pair_freqs, pair)
        update_pair_freqs(pair_freqs, pair_freqs_d)
        pair_inhereit_d = compute_inhereit_deltas(plan)
        exclude_pair_from_dict(pair_inhereit, pair)
        update_pair_inhereit(pair_inhereit, pair_inhereit_d)

        # The original never advanced this counter, so the loop condition never
        # changed and every merge overwrote the same id.
        vocab_len += 1
    return vocab, merges

## Encoder and Decoder

In [None]:
class Tokenizer():
    def __init__(self, vocab:  dict[int, bytes], merges: list[Pair], special_tokens: list[str] | None=None):
        self.vocab = vocab
        self.merges = merges
        self.special_tokens = special_tokens
        
        self.add_special_tokes_to_vocab()
        self.generate_tok_to_id()
        
        
    def add_special_tokes_to_vocab(self):
        if self.special_tokens is not None:
            toks = sorted(self.special_tokens, key=len, reverse=True)
            vocab_len = len(self.vocab)
            vocab_values = set(x for x in self.vocab.values())
            for tok in toks:
                if bytes(tok.encode("utf-8")) not in vocab_values:
                    self.vocab[vocab_len] = bytes(tok.encode("utf-8"))
                    vocab_len += 1
                    
    def generate_tok_to_id(self):
        self.tok_to_id = {value: key for key, value in self.vocab.items()}
        
        
    @classmethod
    def from_files(
        cls, vocab_filepath: str, merges_filepath: str, special_tokens: list[str] | None = None
    ):
        """
        Class method that constructs and return a Tokenizer from a serialized vocabulary and list of merges
        (in the same format that your BPE training code output) and (optionally) a list of special
        tokens.
        """
        import pickle

        with open(vocab_filepath, "rb") as vf:
            vocab: dict[int, bytes] = pickle.load(vf)
        with open(merges_filepath, "rb") as mf:
            merges: list[Pair] = pickle.load(mf)
        return cls(vocab=vocab, merges=merges, special_tokens=special_tokens)
    
    def encode(self, text: str) -> list[int]:
        segments = self.split_text_to_segments(text)
        encoded_segments, _ = self.encode_segments(segments)
        return np.array(encoded_segments)
    
    def encode_iterable(self, iterable: Iterable[str]) -> Iterator[int]:
        for chunk in iterable:
            for _id in self.encode(chunk):
                yield _id
    
    def decode(self, ids: list[int]) -> str:
        text_list = [
            self.vocab[id].decode('utf-8', errors='replace') if id in self.vocab
            else '\uFFFD'
            for id in ids
        ]
        text = "".join(text_list)
        return text
    
    
    ## Helper
    def split_text_to_segments(self, text: str) -> list[str]:
        text = text.replace("\r\n", "\n").replace("\r", "")
        if self.special_tokens is None:
            return [text]
        else:
            pattern = "("+"|".join(re.escape(tok) for tok in self.special_tokens)+")"
            segments = re.split(pattern,text)
        return segments
    
    def encode_segments(self, segments:list[str]) -> tuple[list[int], list[str]]:
        encoded_segments = []
        merged_segments = []
        for seg in segments:
            encoded_s, merged_s = self.encode_seg(seg)
            encoded_segments += encoded_s
            merged_segments += merged_s
        return encoded_segments, merged_segments
            
    
    def encode_seg(self, seg:str) -> tuple[list[int], list[str]]:
        merged_s = []
        encoded_s = []
        if (self.special_tokens is not None) and (seg in self.special_tokens):
            merged_s += [bytes(seg.encode("utf-8"))]
            encoded_s += [self.tok_to_id[bytes(seg.encode("utf-8"))]]
        else:
            tokens = re.finditer(PAT, seg)
            for m in tokens:
                tok = m.group(0)
                bytes_tok = tuple(bytes([b]) for b in tok.encode("utf-8"))
                merged_tok =  self.bpe_encoding_tok(bytes_tok)
                merged_s += merged_tok
                encoder_tok = []
                for i in merged_tok:
                    encoder_tok.append(self.tok_to_id[i])
                encoded_s += (encoder_tok)
        return encoded_s, merged_s
    
    def bpe_encoding_tok(self, bytes_tok: tuple[bytes,...]) -> tuple[bytes,...]:
        start_len = len(bytes_tok)+1
        while len(bytes_tok) != start_len:
            start_len = len(bytes_tok)
            for pair in self.merges:
                pos = self.is_subtuple(pair, bytes_tok)
                if pos != -1:
                    bytes_tok = bytes_tok[:pos]+(bytes_tok[pos]+bytes_tok[pos+1],)+bytes_tok[pos+2:]
                    break
        return bytes_tok
    
    
    def is_subtuple(self,small: tuple, big: tuple) -> int:
        n, m = len(small), len(big)
        for i in range(m - n + 1):
            if big[i:i+n] == small:
                return i
        return -1




In [434]:
import pickle

# Load the pickled vocabulary (a TinyStories-trained one, judging by the file name).
# NOTE(review): pickle.load can execute arbitrary code; only open trusted files.
with open("./data/TinyStoriesV2-GPT4-train_10000_vocab.pkl", "rb") as f:
    vocab = pickle.load(f)

# Load the merge list stored next to that vocabulary.
with open("./data/TinyStoriesV2-GPT4-train_10000_merges.pkl", "rb") as f:
    merges = pickle.load(f)
    
# Special tokens passed to the Tokenizer in the next cell.
special_tokens = ['<|endoftext|>','<|startoftext|>']

In [435]:
# Build a tokenizer from the vocab/merges loaded in the previous cell.
tokenizer = Tokenizer(vocab, merges, special_tokens)

with open("./data/TinyStoriesV2-GPT4-valid.txt", "r", encoding="utf-8") as f:
    data = f.read()
# Split into documents on the end-of-text marker and drop empty ones.
documents = data.split("<|endoftext|>")
documents = [doc.strip() for doc in documents if doc.strip()]
# NOTE(review): the sample is unseeded, so the printed numbers change on every
# run; random.sample raises ValueError if fewer than 10 documents exist.
sampled_docs = random.sample(documents, 10)

# Report bytes per token for each sampled document.
for sampled_doc in sampled_docs:
    encoded_tokens = tokenizer.encode(sampled_doc)
    num_bytes = len(sampled_doc.encode("utf-8"))
    num_tokens = len(encoded_tokens)
    bytes_per_token = num_bytes / num_tokens
    print(f"Bytes: {num_bytes}, Tokens: {num_tokens}, Ratio (bytes/token): {bytes_per_token:.2f}")

Bytes: 688, Tokens: 176, Ratio (bytes/token): 3.91
Bytes: 775, Tokens: 198, Ratio (bytes/token): 3.91
Bytes: 631, Tokens: 151, Ratio (bytes/token): 4.18
Bytes: 955, Tokens: 240, Ratio (bytes/token): 3.98
Bytes: 948, Tokens: 207, Ratio (bytes/token): 4.58
Bytes: 2526, Tokens: 639, Ratio (bytes/token): 3.95
Bytes: 573, Tokens: 133, Ratio (bytes/token): 4.31
Bytes: 595, Tokens: 151, Ratio (bytes/token): 3.94
Bytes: 540, Tokens: 143, Ratio (bytes/token): 3.78
Bytes: 634, Tokens: 155, Ratio (bytes/token): 4.09


In [432]:
import pickle

# Load the pickled vocabulary (an OpenWebText-trained one, judging by the file name).
# NOTE(review): pickle.load can execute arbitrary code; only open trusted files.
with open("./data/owt_train_32000_vocab.pkl", "rb") as f:
    vocab = pickle.load(f)

# Load the merge list stored next to that vocabulary.
with open("./data/owt_train_32000_merges.pkl", "rb") as f:
    merges = pickle.load(f)
    
# Special tokens passed to the Tokenizer in the next cell.
special_tokens = ['<|endoftext|>','<|startoftext|>']

In [433]:
# Build a tokenizer from the vocab/merges loaded in the previous cell.
tokenizer = Tokenizer(vocab, merges, special_tokens)

with open("./data/owt_valid.txt", "r", encoding="utf-8") as f:
    data = f.read()
# Split into documents on the end-of-text marker and drop empty ones.
documents = data.split("<|endoftext|>")
documents = [doc.strip() for doc in documents if doc.strip()]
# NOTE(review): unseeded sample; results are not reproducible across runs.
sampled_docs = random.sample(documents, 10)

# Report bytes per token for each sampled document.
for sampled_doc in sampled_docs:
    encoded_tokens = tokenizer.encode(sampled_doc)
    num_bytes = len(sampled_doc.encode("utf-8"))
    num_tokens = len(encoded_tokens)
    bytes_per_token = num_bytes / num_tokens
    print(f"Bytes: {num_bytes}, Tokens: {num_tokens}, Ratio (bytes/token): {bytes_per_token:.2f}")

Bytes: 7775, Tokens: 1641, Ratio (bytes/token): 4.74
Bytes: 3545, Tokens: 779, Ratio (bytes/token): 4.55
Bytes: 14206, Tokens: 3226, Ratio (bytes/token): 4.40
Bytes: 2007, Tokens: 445, Ratio (bytes/token): 4.51
Bytes: 4100, Tokens: 897, Ratio (bytes/token): 4.57
Bytes: 7657, Tokens: 1637, Ratio (bytes/token): 4.68
Bytes: 1127, Tokens: 245, Ratio (bytes/token): 4.60
Bytes: 1458, Tokens: 315, Ratio (bytes/token): 4.63
Bytes: 6856, Tokens: 1504, Ratio (bytes/token): 4.56
Bytes: 2548, Tokens: 519, Ratio (bytes/token): 4.91


In [436]:
import pickle

# Reload the TinyStories-named vocabulary for the cross-dataset check in the next cell.
# NOTE(review): pickle.load can execute arbitrary code; only open trusted files.
with open("./data/TinyStoriesV2-GPT4-train_10000_vocab.pkl", "rb") as f:
    vocab = pickle.load(f)

# Load the merge list stored next to that vocabulary.
with open("./data/TinyStoriesV2-GPT4-train_10000_merges.pkl", "rb") as f:
    merges = pickle.load(f)
    
# Special tokens passed to the Tokenizer in the next cell.
special_tokens = ['<|endoftext|>','<|startoftext|>']

In [437]:
# Tokenizer from the TinyStories-named files, applied to the OWT validation text.
tokenizer = Tokenizer(vocab, merges, special_tokens)

with open("./data/owt_valid.txt", "r", encoding="utf-8") as f:
    data = f.read()
# Split into documents on the end-of-text marker and drop empty ones.
documents = data.split("<|endoftext|>")
documents = [doc.strip() for doc in documents if doc.strip()]
# NOTE(review): unseeded sample; results are not reproducible across runs.
sampled_docs = random.sample(documents, 10)

# Report bytes per token for each sampled document.
for sampled_doc in sampled_docs:
    encoded_tokens = tokenizer.encode(sampled_doc)
    num_bytes = len(sampled_doc.encode("utf-8"))
    num_tokens = len(encoded_tokens)
    bytes_per_token = num_bytes / num_tokens
    print(f"Bytes: {num_bytes}, Tokens: {num_tokens}, Ratio (bytes/token): {bytes_per_token:.2f}")

Bytes: 3426, Tokens: 1251, Ratio (bytes/token): 2.74
Bytes: 1116, Tokens: 376, Ratio (bytes/token): 2.97
Bytes: 838, Tokens: 257, Ratio (bytes/token): 3.26
Bytes: 3339, Tokens: 879, Ratio (bytes/token): 3.80
Bytes: 767, Tokens: 210, Ratio (bytes/token): 3.65
Bytes: 2438, Tokens: 712, Ratio (bytes/token): 3.42
Bytes: 1750, Tokens: 568, Ratio (bytes/token): 3.08
Bytes: 2113, Tokens: 639, Ratio (bytes/token): 3.31
Bytes: 5097, Tokens: 1518, Ratio (bytes/token): 3.36
Bytes: 3091, Tokens: 944, Ratio (bytes/token): 3.27


In [439]:
import pickle

# Load the OWT-named vocabulary and merges.
# NOTE(review): pickle.load can execute arbitrary code; only open trusted files.
with open("./data/owt_train_32000_vocab.pkl", "rb") as f:
    vocab = pickle.load(f)

with open("./data/owt_train_32000_merges.pkl", "rb") as f:
    merges = pickle.load(f)
    
special_tokens = ['<|endoftext|>','<|startoftext|>']

tokenizer = Tokenizer(vocab, merges, special_tokens)

with open("./data/owt_valid.txt", "r", encoding="utf-8") as f:
    data = f.read()
# Split into documents on the end-of-text marker and drop empty ones.
documents = data.split("<|endoftext|>")
documents = [doc.strip() for doc in documents if doc.strip()]
# NOTE(review): unseeded sample, so the measured workload differs between runs.
sampled_docs = random.sample(documents, 10)
# Totals accumulated over all sampled documents.
num_tokens = 0
num_bytes = 0
# Time only the encode calls (plus the byte counting inside the loop).
t0 = time.perf_counter()
for sampled_doc in sampled_docs:
    encoded_tokens = tokenizer.encode(sampled_doc)
    num_bytes += len(sampled_doc.encode("utf-8"))
    num_tokens += len(encoded_tokens)
t1 = time.perf_counter()
total_time = t1-t0
# Throughput in tokens and in input bytes per second.
token_per_s = num_tokens/total_time
byte_per_s = num_bytes/total_time
print(f"\n⏱️ 总计时间: {total_time:.2f}s")
print(f"\n⏱️ token/s: {token_per_s:.2f}tokens/s")
print(f"\n⏱️ byte/s: {byte_per_s:.2f}byte/s")


⏱️ 总计时间: 27.86s

⏱️ token/s: 408.26tokens/s

⏱️ byte/s: 1752.69byte/s


In [413]:
text = """hi. i'm yifan li. nice to meet you. <|endoftext|>
this the what when here where. <|endoftext|>
where is the car. <|endoftext|>
when is dinner"""


In [417]:
segments = tokenizer.split_text_to_segments(text)
encoded_segments, merged_segments = tokenizer.encode_segments(segments)
encoded_segments[:10]

[5016, 46, 1814, 2000, 339, 368, 273, 17125, 46, 3775]

In [418]:
tokenizer.decode(encoded_segments)

"hi. i'm yifan li. nice to meet you. <|endoftext|>\nthis the what when here where. <|endoftext|>\nwhere is the car. <|endoftext|>\nwhen is dinner"

## Check Result

In [122]:
import heapq

def top_n_longest_byte_values(d: dict, n: int):
    """Return the `n` `(key, value)` items of `d` with the longest values, longest first."""
    def value_length(item):
        return len(item[1])

    return heapq.nlargest(n, d.items(), key=value_length)

In [123]:
top_n_longest_byte_values(vocab,50)

[(31325, b'---------------------------'),
 (30259, b'-------------------------'),
 (28791, b'-----------------------'),
 (27301, b'---------------------'),
 (23380, b' disproportionately'),
 (24327, b' telecommunications'),
 (26042, b'-------------------'),
 (28344, b' environmentalists'),
 (31736, b' -----------------'),
 (14314, b' responsibilities'),
 (16321, b' unconstitutional'),
 (24626, b'-----------------'),
 (25755, b' cryptocurrencies'),
 (26128, b' disproportionate'),
 (27109, b' misunderstanding'),
 (28573, b' counterterrorism'),
 (29907, b'_________________'),
 (30293, b' characterization'),
 (9262, b' representatives'),
 (10287, b' recommendations'),
 (10676, b' characteristics'),
 (14155, b' straightforward'),
 (14688, b' Representatives'),
 (16411, b' internationally'),
 (19522, b' vulnerabilities'),
 (21512, b' Charlottesville'),
 (22064, b' accomplishments'),
 (23596, b' interpretations'),
 (25281, b' implementations'),
 (25289, b' representations'),
 (25679, b' exper

In [124]:
for key, value in vocab.items():
    if value == b'the':
        print(key)
        break

1113


In [133]:
(b't',b'he') in merge

True

(b' ', b't', b'h', b'e')

In [182]:

bytes_tok

(b' the',)

In [183]:
bytes_tok

(b' the',)

In [106]:
# Scratch experiment: repeatedly sweep `bytes_tok` left to right and join every
# adjacent pair that appears in `merge`, until a sweep no longer shortens it.
# NOTE(review): `merge` and `bytes_tok` are not defined in this cell; they come
# from earlier kernel state (`merge` is possibly the `merges` list — confirm).
# This tests plain membership and ignores the order of the merge list.
index = 0
start_len = len(bytes_tok)+1
while len(bytes_tok) != start_len:
    start_len = len(bytes_tok)
    encodered_tok = []
    index = 0
    while index < len(bytes_tok):
        if (index < len(bytes_tok)-1) and ((bytes_tok[index],bytes_tok[index+1]) in merge):
            encodered_tok.append(bytes_tok[index]+bytes_tok[index+1])
            index += 2
        else:
            encodered_tok.append(bytes_tok[index])
            index += 1
    bytes_tok = tuple(encodered_tok)
    print(bytes_tok)
bytes_tok

(b'th', b'e')
(b'th', b'e')


(b'th', b'e')

In [107]:
tok='the'
bytes_tok = tuple(bytes([b]) for b in tok.encode("utf-8"))
bytes_tok

(b't', b'h', b'e')

In [108]:
index = 0
start_len = len(bytes_tok)+1


In [109]:
len(bytes_tok) != start_len

True

In [114]:
start_len = len(bytes_tok)
encodered_tok = []
index = 0
while index < len(bytes_tok):
    if (index < len(bytes_tok)-1) and ((bytes_tok[index],bytes_tok[index+1]) in merge):
        encodered_tok.append(bytes_tok[index]+bytes_tok[index+1])
        index += 2
    else:
        encodered_tok.append(bytes_tok[index])
        index += 1
bytes_tok = tuple(encodered_tok)
bytes_tok

(b'th', b'e')

In [104]:
start_len = len(bytes_tok)
encodered_tok = []

In [120]:
(b'th', b'e') in merge

False

In [116]:
merge

[(b' ', b't'),
 (b' ', b'a'),
 (b'h', b'e'),
 (b'i', b'n'),
 (b'r', b'e'),
 (b' t', b'he'),
 (b'o', b'n'),
 (b'e', b'r'),
 (b' ', b's'),
 (b' ', b'w'),
 (b'a', b't'),
 (b' ', b'o'),
 (b'e', b'n'),
 (b' ', b'c'),
 (b'i', b't'),
 (b'i', b's'),
 (b'a', b'n'),
 (b'o', b'r'),
 (b' ', b'b'),
 (b'e', b's'),
 (b'e', b'd'),
 (b' ', b'f'),
 (b'in', b'g'),
 (b' ', b'p'),
 (b'o', b'u'),
 (b' a', b'n'),
 (b'a', b'l'),
 (b' t', b'o'),
 (b'a', b'r'),
 (b' ', b'm'),
 (b' ', b'in'),
 (b' o', b'f'),
 (b' ', b'h'),
 (b' ', b'd'),
 (b'\xe2', b'\x80'),
 (b'a', b's'),
 (b'i', b'c'),
 (b' an', b'd'),
 (b' t', b'h'),
 (b'l', b'e'),
 (b'o', b'm'),
 (b'i', b'on'),
 (b'l', b'l'),
 (b'en', b't'),
 (b' ', b'n'),
 (b' ', b'l'),
 (b' ', b're'),
 (b's', b't'),
 (b'v', b'e'),
 (b' ', b'e'),
 (b'l', b'y'),
 (b'r', b'o'),
 (b' b', b'e'),
 (b' ', b'g'),
 (b'i', b'd'),
 (b'u', b't'),
 (b'a', b'c'),
 (b'o', b't'),
 (b' ', b'T'),
 (b' ', b'I'),
 (b' th', b'at'),
 (b' ', b'on'),
 (b'a', b'y'),
 (b' ', b'S'),
 (b' ', b'is'),


## 持续优化

In [26]:
special_tokens = ['<|endoftext|>']
input_path = fr'./data/test.txt'
vocab_size = 280

In [27]:
# Rebuild the state train_bpe creates before its merge loop, as notebook
# globals, so that single merge steps can be run and inspected by hand.
vocab = build_initial_vocab(special_tokens)
vocab_len = len(vocab)
merges = []

# Read the corpus, strip special tokens, count pre-tokens, and re-key the counts.
string = read_text_file(input_path)
string_list = split_by_special_tokens(string,special_tokens)
token_count = count_tokens(string_list)
encoded_token_freqs = encode_and_count_tokens(token_count)

# Pair frequencies and the pair -> containing-tokens index.
pair_freqs = get_byte_pairs(encoded_token_freqs)
pair_inhereit = get_byte_pairs_inhereit(encoded_token_freqs)

In [28]:
## while
pair = select_merge_pair(pair_freqs, vocab)
index1, index2 = pair
new_token = vocab[int(index1)]+vocab[int(index2)]
new_index = vocab_len
vocab[new_index] = new_token
merges.append((vocab[int(index1)], vocab[int(index2)]))
tok_need_update = pair_inhereit[pair]
plan = build_merge_plan(tok_need_update, encoded_token_freqs, pair, new_index)
pair_freqs_d = compute_freqs_deltas(plan, new_index)
exclude_pair_from_dict(pair_freqs, pair)
update_pair_freqs(pair_freqs, pair_freqs_d)
pair_inhereit_d = compute_inhereit_deltas(plan)
exclude_pair_from_dict(pair_inhereit, pair)
update_pair_inhereit(pair_inhereit, pair_inhereit_d)

defaultdict(set,
            {(104, 105): {(32, 116, 104, 105, 115), (104, 105)},
             (32, 105): {(32, 105), (32, 105, 115)},
             (39, 109): {(39, 109)},
             (32, 121): {(32, 121, 105, 102, 97, 110), (32, 121, 111, 117)},
             (121, 105): {(32, 121, 105, 102, 97, 110)},
             (105, 102): {(32, 121, 105, 102, 97, 110)},
             (102, 97): {(32, 121, 105, 102, 97, 110)},
             (97, 110): {(32, 121, 105, 102, 97, 110)},
             (32, 108): {(32, 108, 105)},
             (108, 105): {(32, 108, 105)},
             (32, 110): {(32, 110, 105, 99, 101)},
             (110, 105): {(32, 110, 105, 99, 101)},
             (105, 99): {(32, 110, 105, 99, 101)},
             (99, 101): {(32, 110, 105, 99, 101)},
             (32, 116): {(32, 116, 104, 105, 115),
              (32, 116, 111),
              (32, 116, 257)},
             (116, 111): {(32, 116, 111)},
             (32, 109): {(32, 109, 101, 101, 116)},
             (109, 101): {(3

In [16]:
plan

[TokenMergePlan(old_token=(32, 119, 104, 101, 110), new_token=(32, 119, 257, 110), count=2, pair_positions=[2]),
 TokenMergePlan(old_token=(32, 104, 101, 114, 101), new_token=(32, 257, 114, 101), count=1, pair_positions=[1]),
 TokenMergePlan(old_token=(32, 116, 104, 101), new_token=(32, 116, 257), count=2, pair_positions=[2]),
 TokenMergePlan(old_token=(32, 119, 104, 101, 114, 101), new_token=(32, 119, 257, 114, 101), count=2, pair_positions=[2])]

[TokenMergePlan(old_token=(32, 119, 104, 101, 110), new_token=(32, 119, 257, 110), count=2, pair_positions=[2]),
 TokenMergePlan(old_token=(32, 104, 101, 114, 101), new_token=(32, 257, 114, 101), count=1, pair_positions=[1]),
 TokenMergePlan(old_token=(32, 116, 104, 101), new_token=(32, 116, 257), count=2, pair_positions=[2]),
 TokenMergePlan(old_token=(32, 119, 104, 101, 114, 101), new_token=(32, 119, 257, 114, 101), count=2, pair_positions=[2])]

[(31325, b'---------------------------'),
 (30259, b'-------------------------'),
 (28791, b'-----------------------'),
 (27301, b'---------------------'),
 (23380, b' disproportionately'),
 (24327, b' telecommunications'),
 (26042, b'-------------------'),
 (28344, b' environmentalists'),
 (31736, b' -----------------'),
 (14314, b' responsibilities'),
 (16321, b' unconstitutional'),
 (24626, b'-----------------'),
 (25755, b' cryptocurrencies'),
 (26128, b' disproportionate'),
 (27109, b' misunderstanding'),
 (28573, b' counterterrorism'),
 (29907, b'_________________'),
 (30293, b' characterization'),
 (9262, b' representatives'),
 (10287, b' recommendations'),
 (10676, b' characteristics'),
 (14155, b' straightforward'),
 (14688, b' Representatives'),
 (16411, b' internationally'),
 (19522, b' vulnerabilities'),
 (21512, b' Charlottesville'),
 (22064, b' accomplishments'),
 (23596, b' interpretations'),
 (25281, b' implementations'),
 (25289, b' representations'),
 (25679, b' exper

In [14]:
top_n_longest_byte_values(vocab,20)

[(7172, b' accomplishment'),
 (9156, b' disappointment'),
 (9393, b' responsibility'),
 (3236, b' uncomfortable'),
 (3524, b' compassionate'),
 (5327, b' understanding'),
 (6401, b' neighbourhood'),
 (6512, b' Unfortunately'),
 (6888, b' determination'),
 (7769, b' encouragement'),
 (8638, b' unfortunately'),
 (8711, b' congratulated'),
 (8881, b' extraordinary'),
 (9108, b' granddaughter'),
 (256, b'<|endoftext|>'),
 (3346, b' disappointed'),
 (3777, b' enthusiastic'),
 (4368, b' accidentally'),
 (4383, b' refrigerator'),
 (4479, b' veterinarian')]

# 个人优化

## 对比最终结果

In [5]:
special_tokens = ['<|endoftext|>']
input_path = fr'./data/test.txt'
vocab_size = 280

In [7]:
vocab, merges = train_bpe(input_path, vocab_size, special_tokens)
merges

[(b'h', b'e'),
 (b' ', b'w'),
 (b' w', b'he'),
 (b' ', b't'),
 (b'r', b'e'),
 (b'i', b's'),
 (b' whe', b're'),
 (b' whe', b'n'),
 (b' t', b'he'),
 (b' ', b'y'),
 (b' ', b'is'),
 (b'o', b'u'),
 (b'n', b'n'),
 (b'nn', b'e'),
 (b'nne', b'r'),
 (b'n', b'i'),
 (b'ni', b'c'),
 (b'nic', b'e'),
 (b'm', b'e'),
 (b'me', b'e'),
 (b'mee', b't'),
 (b'l', b'i'),
 (b'i', b'nner')]

## 对比每一步

In [23]:
special_tokens = ['<|endoftext|>']
input_path = fr'./data/test.txt'
vocab_size = 280
# string = """hi. i'm yifan li. nice to meet you. <|endoftext|>
# this the what when here where. <|endoftext|>
# where is the car. <|endoftext|>
# when is dinner"""

In [27]:
# Rebuild the pre-loop training state as notebook globals for step-by-step comparison.
vocab = build_initial_vocab(special_tokens)
vocab_len = len(vocab)
merges = []

string = read_text_file(input_path)
# Convert "\r\n" to "\n" and drop any remaining "\r".
# NOTE(review): train_bpe as defined above does not perform this step.
string = string.replace("\r\n", "\n").replace("\r", "")
string_list = split_by_special_tokens(string,special_tokens)
token_count = count_tokens(string_list)
encoded_token_freqs = encode_and_count_tokens(token_count)

# Pair frequencies and the pair -> containing-tokens index.
pair_freqs = get_byte_pairs(encoded_token_freqs)
pair_inhereit = get_byte_pairs_inhereit(encoded_token_freqs)

In [28]:
sorted_items = sorted(token_count.items(), key=lambda x: x[1], reverse=True)
sorted_items

[('.', 5),
 (' ', 3),
 ('"', 2),
 (' the', 2),
 (' when', 2),
 (' where', 2),
 (' is', 2),
 ('hi', 1),
 (' i', 1),
 ("'m", 1),
 (' yifan', 1),
 (' li', 1),
 (' nice', 1),
 (' to', 1),
 (' meet', 1),
 (' you', 1),
 (' this', 1),
 (' what', 1),
 (' here', 1),
 (' car', 1),
 (' dinner', 1)]

In [29]:
sorted_items = sorted(pair_freqs.items(), key=lambda x: (x[1], x[0]), reverse=True)
sorted_items

[((104, 101), 7),
 ((119, 104), 5),
 ((32, 119), 5),
 ((101, 114), 4),
 ((32, 116), 4),
 ((116, 104), 3),
 ((114, 101), 3),
 ((105, 115), 3),
 ((32, 105), 3),
 ((104, 105), 2),
 ((101, 110), 2),
 ((32, 121), 2),
 ((121, 111), 1),
 ((121, 105), 1),
 ((116, 111), 1),
 ((111, 117), 1),
 ((110, 110), 1),
 ((110, 105), 1),
 ((110, 101), 1),
 ((109, 101), 1),
 ((108, 105), 1),
 ((105, 110), 1),
 ((105, 102), 1),
 ((105, 99), 1),
 ((104, 97), 1),
 ((102, 97), 1),
 ((101, 116), 1),
 ((101, 101), 1),
 ((100, 105), 1),
 ((99, 101), 1),
 ((99, 97), 1),
 ((97, 116), 1),
 ((97, 114), 1),
 ((97, 110), 1),
 ((39, 109), 1),
 ((32, 110), 1),
 ((32, 109), 1),
 ((32, 108), 1),
 ((32, 104), 1),
 ((32, 100), 1),
 ((32, 99), 1)]

In [30]:
pair_inhereit

defaultdict(set,
            {(104, 105): {(32, 116, 104, 105, 115), (104, 105)},
             (32, 105): {(32, 105), (32, 105, 115)},
             (39, 109): {(39, 109)},
             (32, 121): {(32, 121, 105, 102, 97, 110), (32, 121, 111, 117)},
             (121, 105): {(32, 121, 105, 102, 97, 110)},
             (105, 102): {(32, 121, 105, 102, 97, 110)},
             (102, 97): {(32, 121, 105, 102, 97, 110)},
             (97, 110): {(32, 121, 105, 102, 97, 110)},
             (32, 108): {(32, 108, 105)},
             (108, 105): {(32, 108, 105)},
             (32, 110): {(32, 110, 105, 99, 101)},
             (110, 105): {(32, 110, 105, 99, 101)},
             (105, 99): {(32, 110, 105, 99, 101)},
             (99, 101): {(32, 110, 105, 99, 101)},
             (32, 116): {(32, 116, 104, 101),
              (32, 116, 104, 105, 115),
              (32, 116, 111)},
             (116, 111): {(32, 116, 111)},
             (32, 109): {(32, 109, 101, 101, 116)},
             (109, 101)

In [37]:
# Partial merge step for inspection: pick a pair, register the merged vocab
# entry and build the plan, but do not apply the plan or advance vocab_len.
# NOTE: re-running this cell appends to `merges` again and reuses the same
# new_index, because vocab_len is not incremented here.
pair = select_merge_pair(pair_freqs, vocab)
index1, index2 = pair
# The new entry is the concatenation of the two existing vocab entries.
new_token = vocab[int(index1)]+vocab[int(index2)]
new_index = vocab_len
vocab[new_index] = new_token
merges.append((vocab[int(index1)], vocab[int(index2)]))
# Encoded tokens currently indexed under the chosen pair.
tok_need_update = pair_inhereit[pair]
plan = build_merge_plan(tok_need_update, encoded_token_freqs, pair, new_index)

In [40]:
# Compute the pair-frequency deltas implied by the plan. The object's default
# repr is not informative; inspect its .inc / .dec dictionaries instead.
pair_freqs_d = compute_freqs_deltas(plan, new_index)
pair_freqs_d

<__main__.PairFreqsDelta at 0x1278624a0>

In [41]:
pair_freqs_d.dec

defaultdict(int,
            {(119, 104): 4,
             (101, 110): 2,
             (32, 104): 1,
             (101, 114): 3,
             (116, 104): 2})

In [92]:
# One complete merge step run by hand; every piece of shared state is
# reassigned and vocab_len is advanced at the end.
# 1) Choose the pair and register the concatenation of its two vocab entries.
pair = select_merge_pair(pair_freqs, vocab)
index1, index2 = pair
new_token = vocab[int(index1)]+vocab[int(index2)]
new_index = vocab_len
vocab[new_index] = new_token
merges.append((vocab[int(index1)], vocab[int(index2)]))
# 2) Plan the rewrite of every encoded token indexed under this pair and
#    apply it to the token counts.
tok_need_update = pair_inhereit[pair]
plan = build_merge_plan(tok_need_update, encoded_token_freqs, pair, new_index)
encoded_token_freqs = update_encoded_token_freqs(plan, encoded_token_freqs)
# 3) Refresh the pair counts. The helpers are defined elsewhere; presumably
#    they drop the merged pair and then apply the deltas — TODO confirm.
pair_freqs_d = compute_freqs_deltas(plan, new_index)
pair_freqs = exclude_pair_from_dict(pair_freqs, pair)
pair_freqs = update_pair_freqs(pair_freqs, pair_freqs_d)
# 4) Refresh the pair -> tokens index in the same way.
pair_inhereit_d = compute_inhereit_deltas(plan)
pair_inhereit = exclude_pair_from_dict(pair_inhereit, pair)
pair_inhereit = update_pair_inhereit(pair_inhereit, pair_inhereit_d)
# 5) The next merge will receive the following vocab id.
vocab_len+=1

In [51]:
encoded_token_freqs

defaultdict(int,
            {(104, 105): 1,
             (46,): 5,
             (32, 105): 1,
             (39, 109): 1,
             (32, 121, 105, 102, 97, 110): 1,
             (32, 108, 105): 1,
             (32, 110, 105, 99, 101): 1,
             (32, 116, 111): 1,
             (32, 109, 101, 101, 116): 1,
             (32, 121, 111, 117): 1,
             (32,): 3,
             (32, 116, 104, 105, 115): 1,
             (32, 119, 104, 97, 116): 1,
             (32, 105, 115): 2,
             (32, 99, 97, 114): 1,
             (32, 100, 105, 110, 110, 101, 114): 1,
             (32, 119, 257, 110): 2,
             (32, 257, 114, 101): 1,
             (32, 116, 257): 2,
             (32, 119, 257, 114, 101): 2})

In [None]:
pair_freqs

## Debug: stepping through one merge on a toy corpus

In [82]:
# Inputs for the hand-run debugging session below.
special_tokens = ['<|endoftext|>']
# NOTE(review): the f prefix is unnecessary here — the literal has no placeholders.
input_path = fr'./data/test.txt'
vocab_size = 280
# Small in-memory corpus with the special token between documents.
string = """hi. i'm yifan li. nice to meet you. <|endoftext|>
this the what when here where. <|endoftext|>
where is the car. <|endoftext|>
when is dinner"""

In [83]:
# Build the starting state for the merge loop from the in-memory `string`.
vocab = build_initial_vocab(special_tokens)
vocab_len = len(vocab)
merges = []

# Reading from input_path is disabled; the inline `string` is used instead.
# string = read_text_file(input_path)
string_list = split_by_special_tokens(string,special_tokens)
token_count = count_tokens(string_list)
# Encoded token (tuple of ids) -> count.
encoded_token_freqs = encode_and_count_tokens(token_count)

# pair_freqs: pair -> count; pair_inhereit: pair -> set of encoded tokens
# that contain that pair.
pair_freqs = get_byte_pairs(encoded_token_freqs)
pair_inhereit = get_byte_pairs_inhereit(encoded_token_freqs)

In [84]:
# Merge step stopped halfway: the plan is built, the token counts are
# rewritten and the frequency deltas are computed, while the pair-count /
# pair-index updates and the vocab_len increment stay commented out.
# NOTE: because vocab_len is not advanced, re-running this cell reuses the
# same new_index and appends to `merges` again.
pair = select_merge_pair(pair_freqs, vocab)
index1, index2 = pair
new_token = vocab[int(index1)]+vocab[int(index2)]
new_index = vocab_len
vocab[new_index] = new_token
merges.append((vocab[int(index1)], vocab[int(index2)]))
tok_need_update = pair_inhereit[pair]
plan = build_merge_plan(tok_need_update, encoded_token_freqs, pair, new_index)
encoded_token_freqs = update_encoded_token_freqs(plan, encoded_token_freqs)
pair_freqs_d = compute_freqs_deltas(plan, new_index)
# pair_freqs = exclude_pair_from_dict(pair_freqs, pair)
# pair_freqs = update_pair_freqs(pair_freqs, pair_freqs_d)
# pair_inhereit_d = compute_inhereit_deltas(plan)
# pair_inhereit = exclude_pair_from_dict(pair_inhereit, pair)
# pair_inhereit = update_pair_inhereit(pair_inhereit, pair_inhereit_d)
# vocab_len+=1

In [85]:
plan

[TokenMergePlan(old_token=(119, 104, 101, 114, 101), new_token=(119, 257, 114, 101), count=1, pair_positions=[1]),
 TokenMergePlan(old_token=(32, 119, 104, 101, 114, 101), new_token=(32, 119, 257, 114, 101), count=1, pair_positions=[2]),
 TokenMergePlan(old_token=(119, 104, 101, 110), new_token=(119, 257, 110), count=1, pair_positions=[1]),
 TokenMergePlan(old_token=(32, 119, 104, 101, 110), new_token=(32, 119, 257, 110), count=1, pair_positions=[2]),
 TokenMergePlan(old_token=(32, 104, 101, 114, 101), new_token=(32, 257, 114, 101), count=1, pair_positions=[1]),
 TokenMergePlan(old_token=(32, 116, 104, 101), new_token=(32, 116, 257), count=2, pair_positions=[2])]

In [86]:
pair_freqs_d.inc

defaultdict(int,
            {(119, 257): 4,
             (257, 114): 3,
             (257, 110): 2,
             (32, 257): 1,
             (116, 257): 2})

In [None]:
def build_merge_plan(tok_need_update: set[Encoded_Token], encoded_token_freqs: dict[Encoded_Token, int], pair: Pair, new_index: int) -> list[TokenMergePlan]:
    """Describe how each affected token changes when ``pair`` is merged.

    Args:
        tok_need_update: Encoded tokens to rewrite.
        encoded_token_freqs: Count of each encoded token, read by subscript.
        pair: The pair being merged.
        new_index: Id handed to ``update_encoded_token`` for the merged pair.

    Returns:
        One ``TokenMergePlan`` per token, in the iteration order of
        ``tok_need_update``. Each plan holds the original token, its
        rewritten form, its count, and the result of
        ``find_subtuple_index`` for ``pair`` in the original token.
    """
    return [
        TokenMergePlan(
            old_token=old_token,
            new_token=update_encoded_token(old_token, pair, new_index),
            count=encoded_token_freqs[old_token],
            pair_positions=find_subtuple_index(old_token, pair),
        )
        for old_token in tok_need_update
    ]

In [21]:
def find_diff(A,B):
    """Compare two mappings key by key.

    Returns:
        A 3-tuple of dicts:
          * entries whose key appears only in ``A`` (key -> ``A[key]``),
          * entries whose key appears only in ``B`` (key -> ``B[key]``),
          * keys present in both whose values differ
            (key -> ``(A[key], B[key])``).
    """
    keys_a, keys_b = A.keys(), B.keys()

    # Keys that exist in A but not in B.
    only_in_A = {}
    for key in keys_a - keys_b:
        only_in_A[key] = A[key]

    # Keys that exist in B but not in A.
    only_in_B = {}
    for key in keys_b - keys_a:
        only_in_B[key] = B[key]

    # Keys shared by both mappings whose values disagree.
    diff_values = {}
    for key in keys_a & keys_b:
        if A[key] != B[key]:
            diff_values[key] = (A[key], B[key])

    return only_in_A, only_in_B, diff_values

In [30]:
plan

[TokenMergePlan(old_token=(32, 119, 104, 101, 110), new_token=(32, 119, 257, 110), count=2, pair_positions=[2]),
 TokenMergePlan(old_token=(32, 104, 101, 114, 101), new_token=(32, 257, 114, 101), count=1, pair_positions=[1]),
 TokenMergePlan(old_token=(32, 116, 104, 101), new_token=(32, 116, 257), count=2, pair_positions=[2]),
 TokenMergePlan(old_token=(32, 119, 104, 101, 114, 101), new_token=(32, 119, 257, 114, 101), count=2, pair_positions=[2])]

In [22]:
# Diff the pair -> tokens index against a second index to see where they disagree.
# NOTE(review): `pair_inhereit2` is not defined in the surrounding cells —
# confirm where it is built before relying on a fresh-kernel run.
only_in_A, only_in_B, diff_values = find_diff(pair_inhereit,pair_inhereit2)

In [29]:
only_in_A

{}

In [26]:
only_in_B

{(257, 114): {(32, 119, 257, 114, 101), (32, 257, 114, 101)},
 (116, 257): {(32, 116, 257)},
 (257, 110): {(32, 119, 257, 110)},
 (119, 257): {(32, 119, 257, 110), (32, 119, 257, 114, 101)},
 (32, 257): {(32, 257, 114, 101)}}

In [27]:
diff_values

{(32, 116): ({(32, 116, 104, 105, 115), (32, 116, 111)},
  {(32, 116, 104, 105, 115), (32, 116, 111), (32, 116, 257)}),
 (32, 119): ({(32, 119, 104, 97, 116)},
  {(32, 119, 104, 97, 116), (32, 119, 257, 110), (32, 119, 257, 114, 101)}),
 (104,
  101): ({(32, 104, 101, 114, 101),
   (32, 116, 104, 101),
   (32, 119, 104, 101, 110),
   (32, 119, 104, 101, 114, 101)}, set()),
 (114, 101): (set(), {(32, 119, 257, 114, 101), (32, 257, 114, 101)})}

In [36]:
pair_inhereit[(32,116)]

{(32, 116, 104, 101), (32, 116, 104, 105, 115), (32, 116, 111)}

In [43]:
pair_inhereit2[(32,119)]

{(32, 119, 104, 97, 116), (32, 119, 257, 110), (32, 119, 257, 114, 101)}

## Unit test

In [45]:
input_path = r"./tests/fixtures/tinystories_sample_5M.txt"

In [48]:
# Train on the fixture at input_path with a 1000-entry vocabulary target and
# one special token; keeps the resulting vocab and merge list for comparison.
vocab_size = 1000
special_tokens=["<|endoftext|>"]
vocab, merges = train_bpe(input_path, vocab_size, special_tokens)

In [49]:
import pickle

# Load the stored snapshot to compare against.
# NOTE(review): pickle.load can execute arbitrary code; only use it on
# snapshot files that come from this repository.
with open("./tests/_snapshots/test_train_bpe_special_tokens.pkl", "rb") as f:
    data = pickle.load(f)
    
# Show which fields the snapshot provides.
data.keys()

dict_keys(['vocab_keys', 'vocab_values', 'merges'])

In [50]:
# Print the first position at which our merges differ from the snapshot's
# merges, then stop. zip ends at the shorter list, so a pure length
# difference is not reported by this loop.
for i, (my_merge, target_merge) in enumerate(zip(merges,data['merges'])):
    if not my_merge == target_merge:
        print([i, my_merge, target_merge])
        break

In [12]:
merges

[(b'h', b'e'),
 (b' ', b't'),
 (b' ', b'a'),
 (b't', b'h'),
 (b' ', b's'),
 (b'a', b'n'),
 (b' ', b'w'),
 (b'n', b'd'),
 (b' ', b'h'),
 (b'e', b'd'),
 (b't', b'o'),
 (b'e', b'r'),
 (b' ', b'b'),
 (b'i', b'n'),
 (b'w', b'a'),
 (b' ', b'f'),
 (b'r', b'e'),
 (b' ', b'T'),
 (b'i', b't'),
 (b'o', b'u'),
 (b'a', b's'),
 (b'h', b'a'),
 (b' ', b'l'),
 (b' ', b'd'),
 (b' ', b'i'),
 (b'e', b'n'),
 (b' ', b'c'),
 (b' ', b'p'),
 (b'a', b'y'),
 (b'a', b'r'),
 (b'm', b'e'),
 (b' ', b'm'),
 (b'o', b'm'),
 (b'v', b'e'),
 (b'a', b't'),
 (b' ', b'o'),
 (b'l', b'e'),
 (b's', b'a'),
 (b'h', b'i'),
 (b'n', b'e'),
 (b'n', b'g'),
 (b' ', b'n'),
 (b'i', b'm'),
 (b'l', b'l'),
 (b't', b'e'),
 (b'i', b'd'),
 (b'a', b'l'),
 (b'k', b'e'),
 (b'T', b'h'),
 (b'o', b'r'),
 (b'i', b's'),
 (b' ', b'g'),
 (b' ', b'S'),
 (b's', b't'),
 (b'e', b'a'),
 (b'c', b'a'),
 (b'o', b'o'),
 (b'l', b'a'),
 (b'l', b'i'),
 (b'l', b'o'),
 (b'i', b'l'),
 (b'o', b't'),
 (b'n', b't'),
 (b'a', b'i'),
 (b'a', b'd'),
 (b'o', b'n'),
 (b'u', b'

In [13]:
data['merges']

[(b'h', b'e'),
 (b' ', b't'),
 (b' ', b'a'),
 (b' ', b's'),
 (b' ', b'w'),
 (b'n', b'd'),
 (b' t', b'he'),
 (b'e', b'd'),
 (b' ', b'b'),
 (b' t', b'o'),
 (b' a', b'nd'),
 (b' ', b'h'),
 (b' ', b'f'),
 (b'i', b'n'),
 (b' w', b'a'),
 (b' ', b'T'),
 (b'i', b't'),
 (b'r', b'e'),
 (b'o', b'u'),
 (b' ', b'l'),
 (b' ', b'd'),
 (b' ', b'c'),
 (b' ', b'p'),
 (b'a', b'y'),
 (b' wa', b's'),
 (b'e', b'r'),
 (b' ', b'm'),
 (b'o', b'm'),
 (b' ', b'he'),
 (b' T', b'he'),
 (b'i', b's'),
 (b' ', b'n'),
 (b'o', b'n'),
 (b'a', b'r'),
 (b'i', b'm'),
 (b' s', b'a'),
 (b'l', b'l'),
 (b'i', b'd'),
 (b' h', b'a'),
 (b' ', b'g'),
 (b' ', b'S'),
 (b'a', b't'),
 (b'in', b'g'),
 (b'o', b't'),
 (b'e', b'n'),
 (b'a', b'n'),
 (b'l', b'e'),
 (b'o', b'r'),
 (b'i', b'r'),
 (b' ', b'H'),
 (b'a', b'm'),
 (b'e', b't'),
 (b' ', b'it'),
 (b' t', b'h'),
 (b'i', b'g'),
 (b' The', b'y'),
 (b'i', b'l'),
 (b' ', b'in'),
 (b' H', b'e'),
 (b' p', b'l'),
 (b' ', b'"'),
 (b'o', b'w'),
 (b'v', b'er'),
 (b'r', b'i'),
 (b' ', b'u'),
 (

In [30]:
# Inputs for stepping through the merge loop by hand on the sample fixture.
input_path = r"./tests/fixtures/tinystories_sample_5M.txt"
vocab_size=1000
special_tokens=["<|endoftext|>"]

In [31]:
# Build the starting state for the merge loop from the file at input_path.
vocab = build_initial_vocab(special_tokens)
vocab_len = len(vocab)
merges = []

string = read_text_file(input_path)
string_list = split_by_special_tokens(string,special_tokens)
token_count = count_tokens(string_list)
# Encoded token (tuple of ids) -> count.
encoded_token_freqs = encode_and_count_tokens(token_count)

# pair_freqs: pair -> count; pair_inhereit: pair -> set of encoded tokens
# that contain that pair.
pair_freqs = get_byte_pairs(encoded_token_freqs)
pair_inhereit = get_byte_pairs_inhereit(encoded_token_freqs)

In [None]:
# One complete merge step run by hand, using the helper that returns the
# frequency deltas and the pair-index deltas together.
# 1) Choose the pair and register the concatenation of its two vocab entries.
pair = select_merge_pair(pair_freqs, vocab)
index1, index2 = pair
new_token = vocab[int(index1)]+vocab[int(index2)]
new_index = vocab_len
vocab[new_index] = new_token
merges.append((vocab[int(index1)], vocab[int(index2)]))
# 2) Plan the rewrite of every encoded token indexed under this pair and
#    apply it to the token counts.
tok_need_update = pair_inhereit[pair]
plan = build_merge_plan(tok_need_update, encoded_token_freqs, pair, new_index)
encoded_token_freqs = update_encoded_token_freqs(plan, encoded_token_freqs)
# 3) Refresh both pair structures. The helpers are defined elsewhere;
#    presumably they drop the merged pair and then apply the deltas — TODO confirm.
pair_freqs_d, pair_inhereit_d = compute_freqs_and_inhereit_deltas(plan, new_index)
pair_freqs = exclude_pair_from_dict(pair_freqs, pair)
pair_freqs = update_pair_freqs(pair_freqs, pair_freqs_d)
pair_inhereit = exclude_pair_from_dict(pair_inhereit, pair)
pair_inhereit = update_pair_inhereit(pair_inhereit, pair_inhereit_d)
# 4) The next merge will receive the following vocab id.
vocab_len+=1

In [33]:
tok_need_update

{(32, 65, 110, 111, 116, 104, 101, 114),
 (32, 66, 114, 111, 116, 104, 101, 114),
 (32, 67, 104, 101, 102),
 (32, 67, 104, 101, 114, 114, 121),
 (32, 69, 118, 101, 114, 121, 119, 104, 101, 114, 101),
 (32, 70, 117, 114, 116, 104, 101, 114, 109, 111, 114, 101),
 (32, 76, 111, 110, 103, 104, 101, 97, 100),
 (32, 78, 101, 105, 116, 104, 101, 114),
 (32, 79, 116, 104, 101, 114),
 (32, 82, 97, 99, 104, 101, 108),
 (32, 83, 104, 101),
 (32, 83, 104, 101, 108, 108),
 (32, 84, 104, 101),
 (32, 84, 104, 101, 105, 114),
 (32, 84, 104, 101, 110),
 (32, 84, 104, 101, 114, 101),
 (32, 84, 104, 101, 115, 101),
 (32, 84, 104, 101, 121),
 (32, 84, 111, 103, 101, 116, 104, 101, 114),
 (32, 87, 104, 101, 101),
 (32, 87, 104, 101, 110),
 (32, 87, 104, 101, 110, 101, 118, 101, 114),
 (32, 87, 104, 101, 114, 101),
 (32, 97, 97, 104, 101, 100),
 (32, 97, 99, 99, 111, 109, 112, 108, 105, 115, 104, 101, 100),
 (32, 97, 99, 104, 101),
 (32, 97, 104, 101, 97, 100),
 (32, 97, 110, 111, 116, 104, 101, 114),
 (32,

In [25]:
# Only select the next pair and form its merged entry; nothing is written
# back to vocab, merges or the pair structures in this cell.
pair = select_merge_pair(pair_freqs, vocab)
index1, index2 = pair
new_token = vocab[int(index1)]+vocab[int(index2)]

In [28]:
pair_freqs[pair]

101570

In [29]:
# Print every pair whose count is at least the selected pair's count, i.e.
# the selected pair itself plus any ties or higher counts. The threshold is
# read from pair_freqs instead of being a number copied from a previous
# output, so the check stays valid when the data or the selected pair changes.
# It is looked up once, before the loop, so the dict is not touched while
# it is being iterated.
selected_freq = pair_freqs[pair]
for key, value in pair_freqs.items():
    if value >= selected_freq:
        print(key)

(116, 104)


## Q1.2 Problem (train_bpe_tinystories): BPE Training on TinyStories 

In [None]:
%load_ext memory_profiler

In [14]:
import time

In [21]:
# Inputs for the TinyStories training run timed in the next cell.
input_path = r"./data/TinyStoriesV2-GPT4-valid.txt"
special_tokens = ['<|endoftext|>']
vocab_size = 32000

In [22]:
start_time = time.time()
%memit vocab, merges = train_bpe(input_path, vocab_size, special_tokens)
end_time = time.time()
print(f"⏱️ 运行时间: {end_time - start_time:.2f} 秒")

peak memory: 356.20 MiB, increment: 23.55 MiB
⏱️ 运行时间: 85.28 秒


In [25]:
input_path = r'./data/TinyStoriesV2-GPT4-train.txt'

extracted

'TinyStoriesV2-GPT4-train'