In [1]:
from basic_tokenizer import BasicTokenizer
from regex_tokenizer import RegexTokenizer
import tiktoken


### Part 1. Test the functionality of BasicTokenizer and RegexTokenizer

In [2]:
# prepare data
# Decode the corpus explicitly as UTF-8. Without `encoding`, open() falls back
# to the platform's locale encoding, so the same file could be read into
# different text (or fail to decode) on another machine.
with open('toy.txt', 'r', encoding='utf-8') as f:
    toy_text = f.read()

# Number of characters (not bytes) in the training text.
print(len(toy_text))

503


In [7]:
# hyper-parameters:
# Target vocabulary size passed to both tokenizers' train() calls below.
# NOTE(review): presumably 256 raw byte tokens plus 128 learned merges — confirm against the tokenizer implementations.
vocab_size = 256 + 128

In [16]:
# Create a BasicTokenizer and train it on the toy corpus up to `vocab_size` tokens.
basic_tokenizer = BasicTokenizer()
basic_tokenizer.train(toy_text, vocab_size)
# Uncomment to inspect the learned merges
# (assumes the tokenizer exposes a `merges` dict — TODO confirm):
# for p, idx in basic_tokenizer.merges.items():
#     print(f'{p} => {idx}\n')

In [17]:
# Quick round-trip check of BasicTokenizer: the string is encoded and then
# decoded again, and the decoded result is printed for comparison with the input.
encode = basic_tokenizer.encode
decode = basic_tokenizer.decode
print(decode(encode("come on rk")))

come on rk


In [9]:
# Create a RegexTokenizer and train it on the same toy corpus.
test_text = toy_text
# Two split patterns are defined here; only GPT4_SPLIT_PATTERN is passed to the
# tokenizer below, GPT2_SPLIT_PATTERN is kept for reference.
# NOTE(review): both patterns use \p{...} classes, which the stdlib `re` module
# does not support — presumably RegexTokenizer compiles them with the
# third-party `regex` package; confirm.
GPT2_SPLIT_PATTERN = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
re_tokenizer = RegexTokenizer(GPT4_SPLIT_PATTERN)
re_tokenizer.train(test_text, vocab_size, verbose=False)

In [10]:
# Quick round-trip check of RegexTokenizer: encode a string, decode the ids,
# and display the decoded result as the cell output.
re_encode = re_tokenizer.encode
re_decode = re_tokenizer.decode
re_decode(re_encode('hello word! this is hello from rk'))

'hello word! this is hello from rk'

### Part 2. Compare RegexTokenizer with GPT-4's tiktoken encoding

In [13]:
# Sample text mixing ASCII, punctuation, Hangul, digits and an emoji, used to
# compare the tokenizers on multi-byte input.
text = 'hello world!!!? (안녕하세요!) lol123 😉'
# Reference encoding loaded from tiktoken by name.
enc_gpt4 = tiktoken.get_encoding('cl100k_base')


In [19]:
# Encode the same sample text with all three tokenizers.
ids_basic = basic_tokenizer.encode(text)
ids_regex = re_tokenizer.encode(text)
ids_gpt4 = enc_gpt4.encode(text)

# Decode each id sequence with the tokenizer that produced it.
text_basic = basic_tokenizer.decode(ids_basic)
text_gpt4 = enc_gpt4.decode(ids_gpt4)
text_regex = re_tokenizer.decode(ids_regex)
# Include the original `text` in the chain: comparing the three decodes only
# with each other would still be True if all of them lost the same information.
text == text_gpt4 == text_regex == text_basic

True

### Part 3. Test the rebuilt GPT-4 tokenizer

In [53]:
# helper function used in recover_merges()
def bpe(mergable_ranks, token, max_rank):
    """Greedily merge the bytes of `token` using the ranks in `mergable_ranks`.

    Starting from one part per byte, the adjacent pair whose concatenation has
    the lowest rank is merged repeatedly. Merging stops when no adjacent pair
    is present in `mergable_ranks`, or when the best available rank is not
    lower than `max_rank`.

    Args:
        mergable_ranks: mapping from a bytes sequence to its integer rank.
        token: bytes object to split into parts.
        max_rank: only merges with a rank strictly below this value are
            applied; pass None to apply every possible merge.

    Returns:
        A list of bytes objects whose concatenation equals `token`.
    """
    parts = [bytes([b]) for b in token]
    while True:
        min_idx = min_rank = None
        # find the adjacent pair that has the lowest rank
        for i, pair in enumerate(zip(parts[:-1], parts[1:])):
            rank = mergable_ranks.get(pair[0] + pair[1])
            if rank is not None and (min_rank is None or rank < min_rank):
                min_idx = i
                min_rank = rank

        # Stop when nothing is mergeable, or when the best merge is one that
        # was learned at or after `max_rank`. Testing with `==` alone would
        # trip the assertion below on tokens with no mergeable pair and would
        # keep merging past `max_rank` when the exact rank never shows up.
        if min_rank is None or (max_rank is not None and min_rank >= max_rank):
            break
        assert min_idx is not None
        # otherwise apply the merge and look for the next one
        parts = parts[:min_idx] + \
                [parts[min_idx] + parts[min_idx + 1]] + \
                parts[min_idx + 2:]

    return parts

In [54]:
# recover the dictionary of merges
def recover_merges(mergeable_ranks):
    """Rebuild the pairwise merge table from a token -> rank mapping.

    The keys of `mergeable_ranks` are byte sequences in their already-merged
    form. For every multi-byte token, `bpe` is run with the token's own rank
    as the limit, which leaves exactly the two parts that were joined to
    create it.

    Args:
        mergeable_ranks: mapping from bytes token to integer rank.

    Returns:
        dict mapping (rank_of_left_part, rank_of_right_part) -> rank of the
        merged token.
    """
    recovered = {}
    for token_bytes, token_rank in mergeable_ranks.items():
        # single raw bytes are not the product of a merge
        if len(token_bytes) != 1:
            halves = tuple(bpe(mergeable_ranks, token_bytes, max_rank=token_rank))
            assert len(halves) == 2
            left_part, right_part = halves
            # translate both parts back into their integer ranks
            key = (mergeable_ranks[left_part], mergeable_ranks[right_part])
            recovered[key] = token_rank
    return recovered

In [82]:
def bytes_to_unicode():
    """Map every byte value 0..255 to a distinct single unicode character.

    Byte values whose character is printable (and is not the space character)
    map to that same character and come first in the dict, in ascending order.
    The remaining byte values follow, also in ascending order, and are assigned
    the code points 256, 257, ... in that order.

    Returns:
        dict {int byte value: str of length 1} with 256 entries.
    """
    printable = [b for b in range(256) if chr(b) != ' ' and chr(b).isprintable()]
    already_mapped = set(printable)
    leftovers = [b for b in range(2**8) if b not in already_mapped]

    table = {b: chr(b) for b in printable}
    for shift, b in enumerate(leftovers):
        # non-printable bytes are moved above the 8-bit range
        table[b] = chr(2**8 + shift)
    return table

def build_shuffled_vocab(merges):
    """Build an index -> bytes vocabulary from a merge table.

    Index i of the first entries is the single byte found at position i of the
    key order returned by `bytes_to_unicode()`. Every merge then adds the
    concatenation of the bytes of its two parts under the merge's own index.

    Args:
        merges: dict mapping (left_index, right_index) -> merged index; both
            parts must already be in the vocabulary when the merge is reached.

    Returns:
        dict {int index: bytes}.
    """
    # iterating the mapping yields its keys, i.e. the raw byte values
    vocab = {
        position: bytes([byte_value])
        for position, byte_value in enumerate(bytes_to_unicode())
    }
    for pair, merged_index in merges.items():
        left, right = pair
        vocab[merged_index] = vocab[left] + vocab[right]
    return vocab

In [83]:
import tiktoken
# Load the encoding by name and read its token -> rank table.
# `_mergeable_ranks` is an underscore-prefixed (private) attribute of the encoding object.
enc = tiktoken.get_encoding('cl100k_base')
mergeable_ranks = enc._mergeable_ranks
# Recover the pairwise merges from the ranks, then rebuild the vocabulary from them.
merges = recover_merges(mergeable_ranks)
vocab1 = build_shuffled_vocab(merges)
# Reference vocabulary: the rank table inverted to rank -> token bytes.
vocab2 = {v: k for k, v in mergeable_ranks.items()}


In [84]:
# Print the sizes of the rebuilt vocabulary and the reference vocabulary side by side.
print(len(vocab1), len(vocab2))

100256 100256


In [None]:
# Count how many of the 100256 indices hold the same bytes in both vocabularies.
x = 0
for i in range(100256):
    # a bool adds 1 when the entries agree and 0 otherwise
    x += int(vocab1[i] == vocab2[i])
print(x)

100256


In [87]:
# Preview the first 10 entries of the rebuilt vocabulary. The loop stops as
# soon as they are printed instead of walking every remaining entry of the
# dictionary just to skip it.
for i, (k, v) in enumerate(vocab1.items()):
    if i >= 10:
        break
    print(k,  v)

0 b'!'
1 b'"'
2 b'#'
3 b'$'
4 b'%'
5 b'&'
6 b"'"
7 b'('
8 b')'
9 b'*'


In [None]:
# ## view the contents in the dictionary '_mergeable_ranks' of gpt4

# with open('gpt4_mergeable_ranks.txt', 'w') as f:
#     for i, (k, v) in enumerate(mergeable_ranks.items()):
#         if (i + 1) % 8 == 0:
#             f.write(f'{k.ljust(16)}: {v:6d}\n')
#         else:
#             f.write(f'{k.ljust(16)}: {v:6d} ')

In [34]:
# Print the leading entries of the rank table, 12 per row, separated by '| '.
for i, (k, v) in enumerate(mergeable_ranks.items()):
    row_complete = (i + 1) % 12 == 0
    # past index 255 the loop ends at the first entry that does not close a row
    if not row_complete and i >= 256:
        break
    print(f'{k}: {v:3d}', end='\n' if row_complete else '| ')

b'!':   0| b'"':   1| b'#':   2| b'$':   3| b'%':   4| b'&':   5| b"'":   6| b'(':   7| b')':   8| b'*':   9| b'+':  10| b',':  11
b'-':  12| b'.':  13| b'/':  14| b'0':  15| b'1':  16| b'2':  17| b'3':  18| b'4':  19| b'5':  20| b'6':  21| b'7':  22| b'8':  23
b'9':  24| b':':  25| b';':  26| b'<':  27| b'=':  28| b'>':  29| b'?':  30| b'@':  31| b'A':  32| b'B':  33| b'C':  34| b'D':  35
b'E':  36| b'F':  37| b'G':  38| b'H':  39| b'I':  40| b'J':  41| b'K':  42| b'L':  43| b'M':  44| b'N':  45| b'O':  46| b'P':  47
b'Q':  48| b'R':  49| b'S':  50| b'T':  51| b'U':  52| b'V':  53| b'W':  54| b'X':  55| b'Y':  56| b'Z':  57| b'[':  58| b'\\':  59
b']':  60| b'^':  61| b'_':  62| b'`':  63| b'a':  64| b'b':  65| b'c':  66| b'd':  67| b'e':  68| b'f':  69| b'g':  70| b'h':  71
b'i':  72| b'j':  73| b'k':  74| b'l':  75| b'm':  76| b'n':  77| b'o':  78| b'p':  79| b'q':  80| b'r':  81| b's':  82| b't':  83
b'u':  84| b'v':  85| b'w':  86| b'x':  87| b'y':  88| b'z':  89| b'{':  90| b'|':

In [24]:
def bytes_to_unicode():
    """
    Build a lookup table from byte values to unicode characters.

    Returns a dict with one entry per byte value 0..255, each mapped to a
    distinct one-character string. Bytes in the ranges "!".."~", "¡".."¬" and
    "®".."ÿ" map to the character with the same code point and appear first in
    the dict. Every other byte (whitespace, control characters, etc.) is
    mapped, in ascending byte order, to the code points 256, 257, ... so that
    no byte ends up represented by a whitespace or control character.
    """
    kept_ranges = (
        (ord("!"), ord("~")),
        (ord("¡"), ord("¬")),
        (ord("®"), ord("ÿ")),
    )
    table = {}
    for first, last in kept_ranges:
        for b in range(first, last + 1):
            table[b] = chr(b)

    shift = 0
    for b in range(2**8):
        if b in table:
            continue
        # bytes outside the kept ranges are moved above the 8-bit range
        table[b] = chr(2**8 + shift)
        shift += 1
    return table

In [27]:
# byte_code maps {int byte value: str character}
byte_code = bytes_to_unicode()
print(byte_code)
# list the mapping again as "byte: char" pairs, 12 pairs per row
for i, (byte, code) in enumerate(byte_code.items()):
    print(f'{byte}: {code}', end='\n' if (i + 1) % 12 == 0 else ' ')

33: ! 34: " 35: # 36: $ 37: % 38: & 39: ' 40: ( 41: ) 42: * 43: + 44: ,
45: - 46: . 47: / 48: 0 49: 1 50: 2 51: 3 52: 4 53: 5 54: 6 55: 7 56: 8
57: 9 58: : 59: ; 60: < 61: = 62: > 63: ? 64: @ 65: A 66: B 67: C 68: D
69: E 70: F 71: G 72: H 73: I 74: J 75: K 76: L 77: M 78: N 79: O 80: P
81: Q 82: R 83: S 84: T 85: U 86: V 87: W 88: X 89: Y 90: Z 91: [ 92: \
93: ] 94: ^ 95: _ 96: ` 97: a 98: b 99: c 100: d 101: e 102: f 103: g 104: h
105: i 106: j 107: k 108: l 109: m 110: n 111: o 112: p 113: q 114: r 115: s 116: t
117: u 118: v 119: w 120: x 121: y 122: z 123: { 124: | 125: } 126: ~ 161: ¡ 162: ¢
163: £ 164: ¤ 165: ¥ 166: ¦ 167: § 168: ¨ 169: © 170: ª 171: « 172: ¬ 174: ® 175: ¯
176: ° 177: ± 178: ² 179: ³ 180: ´ 181: µ 182: ¶ 183: · 184: ¸ 185: ¹ 186: º 187: »
188: ¼ 189: ½ 190: ¾ 191: ¿ 192: À 193: Á 194: Â 195: Ã 196: Ä 197: Å 198: Æ 199: Ç
200: È 201: É 202: Ê 203: Ë 204: Ì 205: Í 206: Î 207: Ï 208: Ð 209: Ñ 210: Ò 211: Ó
212: Ô 213: Õ 214: Ö 215: × 216: Ø 217: Ù 218: Ú 219: Û 220