In [1]:
# Toy corpus: each word repeated a fixed number of times.
low = ["low"] * 5
lower = ["lower"] * 2
widest = ["widest"] * 3
newest = ["newest"] * 6
text = low + lower + widest + newest

In [2]:
# Display the assembled corpus (16 words).
text

['low',
 'low',
 'low',
 'low',
 'low',
 'lower',
 'lower',
 'widest',
 'widest',
 'widest',
 'newest',
 'newest',
 'newest',
 'newest',
 'newest',
 'newest']

In [3]:
from typing import List, Tuple, Dict
import os
from multiprocessing import Pool,cpu_count
import regex as re
from cs336_basics.pretokenization_example import find_chunk_boundaries
import time
import json
from collections import Counter,defaultdict
import heapq

def process_single_chunk(text):
    """Count the words of one chunk and index their adjacent byte pairs.

    Args:
        text: iterable of word strings; every element counts as one
            occurrence of that word.

    Returns:
        dict with four keys:
          "words_to_count":  Counter, word -> number of occurrences in ``text``.
          "words_to_tokens": dict, word -> list of single-byte ``bytes``
                             objects from its UTF-8 encoding.
          "pair_to_words":   defaultdict(set), adjacent byte pair -> set of
                             words containing that pair.
          "pair_to_count":   Counter, adjacent byte pair -> occurrences
                             weighted by the containing word's count.
    """
    # A generator (not a Mapping) keeps Counter in "count each element" mode.
    counts = Counter(word for word in text)

    tokens_of = {}
    words_of_pair = defaultdict(set)
    count_of_pair = Counter()
    for word, freq in counts.items():
        tokens = [bytes([byte]) for byte in word.encode("utf-8")]
        tokens_of[word] = tokens
        # Zipping the list with its own tail yields each adjacent pair in
        # order; it yields nothing for words shorter than two bytes.
        for pair in zip(tokens, tokens[1:]):
            words_of_pair[pair].add(word)
            count_of_pair[pair] += freq

    return {
        "words_to_count": counts,
        "words_to_tokens": tokens_of,
        "pair_to_words": words_of_pair,
        "pair_to_count": count_of_pair,
    }
        
def pre_token(text):
    """Pre-tokenize ``text`` and merge the per-chunk statistics.

    The whole input is handed to ``process_single_chunk`` as one chunk, so
    the list of partial results below currently holds exactly one entry.
    Each partial result is folded into fresh containers.

    Returns:
        dict with keys "words_to_count", "words_to_tokens", "pair_to_words"
        (defaultdict(set)) and "pair_to_count", merged across all chunks.
    """
    partial_results = [process_single_chunk(text)]

    merged_word_counts = Counter()
    merged_pair_counts = Counter()
    merged_word_tokens = {}
    merged_pair_words = defaultdict(set)

    for partial in partial_results:
        merged_word_counts.update(partial["words_to_count"])
        merged_pair_counts.update(partial["pair_to_count"])
        merged_word_tokens.update(partial["words_to_tokens"])
        for pair, words in partial["pair_to_words"].items():
            merged_pair_words[pair].update(words)

    return {
        "words_to_count": merged_word_counts,
        "words_to_tokens": merged_word_tokens,
        "pair_to_words": merged_pair_words,
        "pair_to_count": merged_pair_counts,
    }

In [4]:
# Word counts, byte tokens and pair indexes for the toy corpus.
ans = pre_token(text=text)

In [5]:
# Inspect the pre-tokenization statistics returned by pre_token.
ans

{'words_to_count': Counter({'newest': 6, 'low': 5, 'widest': 3, 'lower': 2}),
 'words_to_tokens': {'low': [b'l', b'o', b'w'],
  'lower': [b'l', b'o', b'w', b'e', b'r'],
  'widest': [b'w', b'i', b'd', b'e', b's', b't'],
  'newest': [b'n', b'e', b'w', b'e', b's', b't']},
 'pair_to_words': defaultdict(set,
             {(b'l', b'o'): {'low', 'lower'},
              (b'o', b'w'): {'low', 'lower'},
              (b'w', b'e'): {'lower', 'newest'},
              (b'e', b'r'): {'lower'},
              (b'w', b'i'): {'widest'},
              (b'i', b'd'): {'widest'},
              (b'd', b'e'): {'widest'},
              (b'e', b's'): {'newest', 'widest'},
              (b's', b't'): {'newest', 'widest'},
              (b'n', b'e'): {'newest'},
              (b'e', b'w'): {'newest'}}),
 'pair_to_count': Counter({(b'e', b's'): 9,
          (b's', b't'): 9,
          (b'w', b'e'): 8,
          (b'l', b'o'): 7,
          (b'o', b'w'): 7,
          (b'n', b'e'): 6,
          (b'e', b'w'): 6,
       

In [9]:
class ReverseSortPair:
    """Heap-key wrapper that inverts the natural ordering of a byte pair.

    ``heapq`` is a min-heap. Wrapping a pair in this class makes the
    lexicographically *greatest* pair compare as the smallest. Combined with
    a ``-count`` primary key, ``heappop`` returns the most frequent pair and
    breaks frequency ties in favour of the lexicographically greater pair.
    """

    def __init__(self, pair):
        self.pair = pair

    def __lt__(self, other):
        # On equal frequency the lexicographically larger pair is "smaller",
        # so it is popped first from the heap.
        return self.pair > other.pair


def _merge_tokens(tokens, pair, new_token):
    """Return a new list with each left-to-right, non-overlapping occurrence
    of ``pair`` in ``tokens`` replaced by ``new_token``."""
    first, second = pair
    merged = []
    i = 0
    last = len(tokens) - 1
    while i < len(tokens):
        if i < last and tokens[i] == first and tokens[i + 1] == second:
            merged.append(new_token)
            i += 2
        else:
            merged.append(tokens[i])
            i += 1
    return merged


def train_bpe(text, vocab_size: int, special_tokens: List[str]) -> Tuple[Dict[int, bytes], List[Tuple[bytes, bytes]]]:
    """Train a byte-level BPE vocabulary from pre-tokenized words.

    Args:
        text: word iterable passed unchanged to ``pre_token``.
        vocab_size: target size of the returned vocabulary, counting the
            special tokens and the 256 single-byte tokens. Training stops
            early if no pair with a positive count remains.
        special_tokens: strings stored (UTF-8 encoded) at IDs
            ``0 .. len(special_tokens) - 1``.

    Returns:
        ``(vocab, merges)``. ``vocab`` maps token ID -> token bytes; the
        256 single bytes follow the special tokens, and merged tokens follow
        in creation order. ``merges`` lists each merged ``(left, right)``
        pair in the order it was applied.
    """
    pre_token_result_dict = pre_token(text)
    w2c = pre_token_result_dict["words_to_count"]
    w2t = pre_token_result_dict["words_to_tokens"]
    p2w = pre_token_result_dict["pair_to_words"]
    p2c = pre_token_result_dict["pair_to_count"]

    # Special tokens first (starting at ID 0), then the 256 single bytes.
    vocab = {}
    for i, special_token in enumerate(special_tokens):
        vocab[i] = special_token.encode("utf-8")
    for i in range(256):
        vocab[len(special_tokens) + i] = bytes([i])
    cur_vocab_size = 256 + len(special_tokens)

    heap = [(-count, ReverseSortPair(pair)) for pair, count in p2c.items() if count > 0]
    heapq.heapify(heap)
    merges = []
    while cur_vocab_size < vocab_size and heap:
        neg_count, wrapper = heapq.heappop(heap)
        best_pair = wrapper.pair
        # Lazy deletion: entries whose count no longer matches are stale.
        # Every count change below pushes a fresh entry, so the up-to-date
        # count of every live pair is always somewhere in the heap.
        if -neg_count != p2c.get(best_pair, 0):
            continue

        merges.append(best_pair)
        new_token = best_pair[0] + best_pair[1]
        vocab[cur_vocab_size] = new_token
        cur_vocab_size += 1

        changed_pairs = set()
        for word in list(p2w[best_pair]):
            old_tokens = w2t[word]
            new_tokens = _merge_tokens(old_tokens, best_pair, new_token)
            freq = w2c[word]
            # Diffing whole-word pair multisets handles adjacent merges
            # (e.g. "abab" -> [ab, ab]) and pairs repeated within a word.
            old_pairs = Counter(zip(old_tokens, old_tokens[1:]))
            new_pairs = Counter(zip(new_tokens, new_tokens[1:]))
            for pair in old_pairs.keys() | new_pairs.keys():
                delta = new_pairs[pair] - old_pairs[pair]
                if delta:
                    p2c[pair] += delta * freq
                    changed_pairs.add(pair)
                # The word stays indexed under a pair while any occurrence remains.
                if pair in new_pairs:
                    p2w[pair].add(word)
                else:
                    p2w[pair].discard(word)
            w2t[word] = new_tokens

        for pair in changed_pairs:
            if p2c[pair] > 0:
                heapq.heappush(heap, (-p2c[pair], ReverseSortPair(pair)))

        p2w.pop(best_pair, None)
        p2c[best_pair] = 0

    return vocab, merges

In [10]:
# 256 byte tokens + 1 special token = 257 base entries, so a target of 263
# leaves room for 6 merges.
special_tokens = ["<|endoftext|>"]
train_bpe(text, vocab_size=263, special_tokens=special_tokens)

({0: b'<|endoftext|>',
  1: b'\x00',
  2: b'\x01',
  3: b'\x02',
  4: b'\x03',
  5: b'\x04',
  6: b'\x05',
  7: b'\x06',
  8: b'\x07',
  9: b'\x08',
  10: b'\t',
  11: b'\n',
  12: b'\x0b',
  13: b'\x0c',
  14: b'\r',
  15: b'\x0e',
  16: b'\x0f',
  17: b'\x10',
  18: b'\x11',
  19: b'\x12',
  20: b'\x13',
  21: b'\x14',
  22: b'\x15',
  23: b'\x16',
  24: b'\x17',
  25: b'\x18',
  26: b'\x19',
  27: b'\x1a',
  28: b'\x1b',
  29: b'\x1c',
  30: b'\x1d',
  31: b'\x1e',
  32: b'\x1f',
  33: b' ',
  34: b'!',
  35: b'"',
  36: b'#',
  37: b'$',
  38: b'%',
  39: b'&',
  40: b"'",
  41: b'(',
  42: b')',
  43: b'*',
  44: b'+',
  45: b',',
  46: b'-',
  47: b'.',
  48: b'/',
  49: b'0',
  50: b'1',
  51: b'2',
  52: b'3',
  53: b'4',
  54: b'5',
  55: b'6',
  56: b'7',
  57: b'8',
  58: b'9',
  59: b':',
  60: b';',
  61: b'<',
  62: b'=',
  63: b'>',
  64: b'?',
  65: b'@',
  66: b'A',
  67: b'B',
  68: b'C',
  69: b'D',
  70: b'E',
  71: b'F',
  72: b'G',
  73: b'H',
  74: b'I',
  75: b'