# Building an LLM from Scratch (Tokenizer Part)

## Byte pair encoding

Suppose the data to be encoded is

```
aaabdaaabac
```

This is 11 tokens, one per byte.

The byte pair "aa" occurs most often, so it will be replaced by a byte that is not used in the data, such as "Z". Now there is the following data and replacement table:

```
ZabdZabac
Z=aa
```
Then the process is repeated with byte pair "ab", replacing it with "Y":

```
ZYdZYac
Y=ab
Z=aa
```
The only literal byte pair left occurs only once, and the encoding might stop here. Alternatively, the process could continue with recursive byte pair encoding, replacing "ZY" with "X":

```
XdXac
X=ZY
Y=ab
Z=aa
```

The data is now 5 tokens.

This data cannot be compressed further by byte pair encoding because there are no pairs of bytes that occur more than once.

To decompress the data, simply perform the replacements in the reverse order.

[Reference](https://en.wikipedia.org/wiki/Byte_pair_encoding)
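
The walkthrough above can be reproduced in a few lines. This is a minimal sketch, not the notebook's tokenizer: the function names `compress` and `decompress` are made up for illustration, and the tie-breaking rule (prefer the lexicographically largest pair among equally frequent ones) is an assumption chosen only because it yields the same replacement order as the example.

```python
from collections import Counter

def compress(data: str, symbols: str = "ZYX"):
    """Repeatedly replace the most frequent adjacent pair with an unused symbol."""
    table = []  # replacement rules, in the order they were created
    for symbol in symbols:
        counts = Counter(zip(data, data[1:]))
        if not counts:
            break
        # Tie-break (an assumption): among equally frequent pairs, take the
        # lexicographically largest one, which matches the walkthrough.
        pair = max(counts, key=lambda p: (counts[p], p))
        if counts[pair] < 2:
            break  # no pair occurs more than once, so nothing is gained
        data = data.replace("".join(pair), symbol)
        table.append((symbol, "".join(pair)))
    return data, table

def decompress(data: str, table):
    """Undo the replacements in reverse order."""
    for symbol, pair in reversed(table):
        data = data.replace(symbol, pair)
    return data

compressed, table = compress("aaabdaaabac")
print(compressed)                     # XdXac
print(table)                          # [('Z', 'aa'), ('Y', 'ab'), ('X', 'ZY')]
print(decompress(compressed, table))  # aaabdaaabac
```

Decompression has to run in reverse because later symbols (`X`) are defined in terms of earlier ones (`Z`, `Y`).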

In [1]:
# Download text from Wikipedia to build a dataset for training a language model
import wikipedia as wiki
import os
import glob

wiki.set_lang("en")
en_topics: list = ["Python (programming language)", "Attention Is All You Need", "Harry Potter", "The Big Bang Theory"]

for topic in en_topics:
    path = "data/{}_en.txt".format(topic.replace(" ", "_"))
    try:
        if os.path.exists(path):
            print("Skipping \"{}\" as it already exists".format(topic))
            continue
        page = wiki.page(topic, auto_suggest=False)
        content = page.content
        os.makedirs("data", exist_ok=True)
        # Write as UTF-8 explicitly so the result does not depend on the platform default
        with open(path, "w", encoding="utf-8") as f:
            f.write(content)
        print("Downloaded \"{}\"".format(topic))
    except Exception as e:
        print("Failed to download \"{}\": {}".format(topic, e))
        continue

# wiki.set_lang("ja")  # Wikipedia's language code for Japanese is "ja"
# jp_topics: list = ["Python", "アテンション (機械学習)","ハリー・ポッターシリーズ", "ビッグバン★セオリー/ギークなボクらの恋愛法則"]

# for topic in jp_topics:
#     try:
#         if os.path.exists("data/{}_jp.txt".format(topic.replace(" ", "_"))):
#             print("Skipping \"{}\" as it already exists".format(topic))
#             continue
#         page = wiki.page(topic, auto_suggest=False)
#         content = page.content
#         os.makedirs("data", exist_ok=True)
#         with open("data/{}_jp.txt".format(topic.replace(" ", "_")), "w") as f:
#             f.write(content)
#         print("Downloaded \"{}\"".format(topic))
#     except:
#         print("Failed to download \"{}\"".format(topic))
#         continue

data_paths = glob.glob("data/*.txt")
training_data_path = os.path.join("./", "training_data.txt")

# Concatenate every downloaded article into a single training file
with open(training_data_path, "w", encoding="utf-8") as f:
    for path in data_paths:
        with open(path, "r", encoding="utf-8") as f2:
            content = f2.read()
            f.write(content)

with open(training_data_path, "r", encoding="utf-8") as f:
    full_content = f.read()
    print("File: {} has {} characters".format(training_data_path, len(full_content)))



Skipping "Python (programming language)" as it already exists
Skipping "Attention Is All You Need" as it already exists
Skipping "Harry Potter" as it already exists
Skipping "The Big Bang Theory" as it already exists
File: ./training_data.txt has 151092 characters


In [2]:
# encode the text as UTF-8 bytes
unicode_text = full_content.encode("utf-8")
print("number of UTF-8 bytes: {} bytes".format(len(unicode_text)))


number of UTF-8 bytes: 151243 bytes


[Radford, Alec, et al. "Language models are unsupervised multitask learners." OpenAI blog 1.8 (2019): 9.](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf)

```
Byte Pair Encoding (BPE) (Sennrich et al., 2015) is a practical middle ground between character and word level language modeling which effectively interpolates between word level inputs for frequent symbol sequences and character level inputs for infrequent symbol sequences. Despite its name, reference BPE implementations often operate on Unicode code points and not byte sequences. These implementations would require including the full space of Unicode symbols in order to model all Unicode strings. This would result in a base vocabulary of over 130,000 before any multi-symbol tokens are added. This is prohibitively large compared to the 32,000 to 64,000 token vocabularies often used with BPE. In contrast, a byte-level version of BPE only requires a base vocabulary of size 256. However, directly applying BPE to the byte sequence results in sub-optimal merges due to BPE using a greedy frequency based heuristic for building the token vocabulary. We observed BPE including many versions of common words like dog since they occur in many variations such as dog. dog! dog? . This results in a sub-optimal allocation of limited vocabulary slots and model capacity. To avoid this, we prevent BPE from merging across character categories for any byte sequence. We add an exception for spaces which significantly improves the compression efficiency while adding only minimal fragmentation of words across multiple vocab tokens.
```
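
The difference between code points and bytes in the quote is why the corpus above has 151092 characters but 151243 bytes after encoding. Here is a small sketch of that distinction; the sample strings are arbitrary choices for illustration.

```python
# Sample strings: ASCII, Latin-1 range, Japanese kana, emoji (written as escapes).
samples = ["dog", "\u00e9", "\u3042", "\U0001F642"]

lengths = {}
for s in samples:
    raw = s.encode("utf-8")
    # len(s) counts Unicode code points; len(raw) counts UTF-8 bytes.
    lengths[s] = (len(s), len(raw))
    print(repr(s), len(s), len(raw), list(raw))

# Every byte value lies in 0..255, so 256 base tokens cover any input text.
all_bytes = [b for s in samples for b in s.encode("utf-8")]
print(all(0 <= b <= 255 for b in all_bytes))
```

A byte-level tokenizer therefore never meets an unknown symbol, at the cost of spending several base tokens on each non-ASCII character until merges are learned.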

In [3]:
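# Keep the byte values as strings for display; the later cells work with ints
tokens = list(map(str,unicode_text))
# Take the maximum over the integer byte values (max() on strings would compare lexicographically)
max_token_id = max(unicode_text)
print(tokens)

print("number of tokens: {} tokens".format(len(tokens)))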

['72', '97', '114', '114', '121', '32', '80', '111', '116', '116', '101', '114', '32', '105', '115', '32', '97', '32', '115', '101', '114', '105', '101', '115', '32', '111', '102', '32', '115', '101', '118', '101', '110', '32', '102', '97', '110', '116', '97', '115', '121', '32', '110', '111', '118', '101', '108', '115', '32', '119', '114', '105', '116', '116', '101', '110', '32', '98', '121', '32', '66', '114', '105', '116', '105', '115', '104', '32', '97', '117', '116', '104', '111', '114', '32', '74', '46', '32', '75', '46', '32', '82', '111', '119', '108', '105', '110', '103', '46', '32', '84', '104', '101', '32', '110', '111', '118', '101', '108', '115', '32', '99', '104', '114', '111', '110', '105', '99', '108', '101', '32', '116', '104', '101', '32', '108', '105', '118', '101', '115', '32', '111', '102', '32', '97', '32', '121', '111', '117', '110', '103', '32', '119', '105', '122', '97', '114', '100', '44', '32', '72', '97', '114', '114', '121', '32', '80', '111', '116', '116',

In [4]:
def get_stats(ids):
    counts = {}
    for pair in zip(ids, ids[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts

stats = get_stats(tokens)
print(stats)
sorted_stats = sorted(((v,k) for k,v in stats.items()),reverse=True)
print(sorted_stats)

{('72', '97'): 255, ('97', '114'): 1477, ('114', '114'): 337, ('114', '121'): 461, ('121', '32'): 1295, ('32', '80'): 609, ('80', '111'): 163, ('111', '116'): 468, ('116', '116'): 297, ('116', '101'): 1295, ('101', '114'): 2045, ('114', '32'): 1294, ('32', '105'): 1315, ('105', '115'): 902, ('115', '32'): 3018, ('32', '97'): 2658, ('97', '32'): 624, ('32', '115'): 1516, ('115', '101'): 992, ('114', '105'): 869, ('105', '101'): 473, ('101', '115'): 1305, ('32', '111'): 1284, ('111', '102'): 727, ('102', '32'): 688, ('101', '118'): 284, ('118', '101'): 724, ('101', '110'): 1216, ('110', '32'): 2216, ('32', '102'): 888, ('102', '97'): 98, ('97', '110'): 1803, ('110', '116'): 840, ('116', '97'): 477, ('97', '115'): 1008, ('115', '121'): 65, ('32', '110'): 322, ('110', '111'): 235, ('111', '118'): 201, ('101', '108'): 632, ('108', '115'): 175, ('32', '119'): 984, ('119', '114'): 58, ('105', '116'): 976, ('32', '98'): 868, ('98', '121'): 170, ('32', '66'): 245, ('66', '114'): 47, ('116', '10

In [5]:
# Which pair of byte values appears most often?
print(sorted_stats[0])

print(f"The most frequent pair of byte values is {(sorted_stats[0][1][0])} {(sorted_stats[0][1][1])}. It appeared {sorted_stats[0][0]} times")

for sorted_stat in sorted_stats:
    print(f"{(sorted_stat[1][0])}{(sorted_stat[1][1])}: {sorted_stat[0]} times")

(3678, ('101', '32'))
The most frequent pair of byte values is 101 32. It appeared 3678 times
10132: 3678 times
11532: 3018 times
116104: 2704 times
3297: 2658 times
32116: 2658 times
104101: 2447 times
10032: 2318 times
105110: 2236 times
11032: 2216 times
101114: 2045 times
111110: 2027 times
97110: 1803 times
4432: 1656 times
114101: 1592 times
32115: 1516 times
97114: 1477 times
11632: 1389 times
32105: 1315 times
101115: 1305 times
12132: 1295 times
116101: 1295 times
11432: 1294 times
32111: 1284 times
110100: 1267 times
97116: 1232 times
101110: 1216 times
111114: 1182 times
101100: 1170 times
110103: 1152 times
116105: 1086 times
97115: 1008 times
115101: 992 times
32119: 984 times
105116: 976 times
97108: 920 times
108101: 903 times
105115: 902 times
32102: 888 times
114105: 869 times
3298: 868 times
3299: 867 times
4632: 853 times
10332: 851 times
110116: 840 times
115116: 837 times
116111: 817 times
11132: 797 times
10197: 754 times
111102: 727 times
100101: 725 tim

In [6]:
def merge(ids, pair, idx):
    # Replace every non-overlapping occurrence of `pair` in `ids` with `idx`, scanning left to right
    newids = []
    i = 0
    while i < len(ids):
        if i < len(ids) - 1 and ids[i] == pair[0] and ids[i+1] == pair[1]:
            newids.append(idx)
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids


In [7]:
# vocab_size = 1000
# num_merges = vocab_size - 256  # ids 0-255 are reserved for the raw bytes
# ids = list(tokens)

# merges = {} 
# for i in range(num_merges):
#     stats = get_stats(ids)
#     pair = max(stats, key=stats.get)
#     idx = 256 + i
#     print(f"merging {pair} into a new token {idx}")
#     ids = merge(ids, pair, idx)
#     merges[pair] = idx

In [8]:
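def get_vocab_dict(merges:dict) -> dict:
    # ids 0-255 map to the single raw bytes
    vocab_dict = {idx: bytes([idx]) for idx in range(256)}
    print(f"merges: {merges}")
    # relies on dict insertion order; a merge whose parts are not known yet is skipped silently
    for (p0, p1), idx in merges.items():
        if ((p0 in vocab_dict) and (p1 in vocab_dict)):
            vocab_dict[idx] = vocab_dict[p0] + vocab_dict[p1]
    return vocab_dict

def decode(ids:list, vocab_dict: dict):
    tokens = b"".join(vocab_dict[idx] for idx in ids)
    # errors='replace' turns invalid or incomplete UTF-8 sequences into U+FFFD instead of raising
    text = tokens.decode("utf-8", errors='replace')
    return text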

# vocab_dict = get_vocab_dict(merges)

# print(decode([128], vocab_dict))
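
`decode` passes `errors='replace'` because a single byte-level token is not always valid UTF-8 on its own. The sketch below shows the effect with the standard library only; the character used is an arbitrary example.

```python
# UTF-8 bytes of one Japanese character (three bytes in total).
raw = "\u3042".encode("utf-8")

# The first byte alone is an incomplete sequence, so strict decoding fails.
try:
    bytes([raw[0]]).decode("utf-8")
    strict_ok = True
except UnicodeDecodeError:
    strict_ok = False

# errors="replace" substitutes U+FFFD instead of raising.
partial = bytes([raw[0]]).decode("utf-8", errors="replace")
lone_continuation = bytes([128]).decode("utf-8", errors="replace")
whole = raw.decode("utf-8", errors="replace")

print(strict_ok, repr(partial), repr(lone_continuation), repr(whole))
```

Decoding a full token sequence is unaffected; the replacement character only shows up when ids are decoded one at a time or a sequence is cut in the middle of a character.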

    

In [9]:
# print(merges)

In [10]:
def encode(text, merges: dict):
    tokens = list(text.encode("utf-8"))
    print(f"encode len tokens: {len(tokens)}")
    while len(tokens) >= 2:
        stats = get_stats(tokens)
        print(f"encode stats: {stats}")
        # pick the pair with the lowest merge id; pairs that are not in merges rank as infinity
        pair = min(stats, key=lambda pair: merges.get(pair, float("inf")))

        # nothing left to merge
        if pair not in merges:
            break

        idx = merges[pair]
        tokens = merge(tokens, pair, idx)

    return tokens
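
To see why `encode` always takes the pair with the lowest merge id, it may help to run it on a tiny hand-written merge table. This is a self-contained sketch: `toy_merges` is hypothetical, the helpers are repeated so the block runs on its own, and the debug prints are left out.

```python
def get_stats(ids):
    counts = {}
    for pair in zip(ids, ids[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts

def merge(ids, pair, idx):
    newids, i = [], 0
    while i < len(ids):
        if i < len(ids) - 1 and ids[i] == pair[0] and ids[i + 1] == pair[1]:
            newids.append(idx)
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids

def encode(text, merges):
    tokens = list(text.encode("utf-8"))
    while len(tokens) >= 2:
        stats = get_stats(tokens)
        # Lowest merge id first; unknown pairs rank as infinity.
        pair = min(stats, key=lambda p: merges.get(p, float("inf")))
        if pair not in merges:
            break
        tokens = merge(tokens, pair, merges[pair])
    return tokens

# Hypothetical table: "he" -> 256, then (256, "l") -> 257, i.e. "hel".
toy_merges = {(104, 101): 256, (256, 108): 257}

print(encode("hello", toy_merges))  # [257, 108, 111]
print(encode("shell", toy_merges))  # [115, 257, 108]
```

The second rule can only fire after the first has produced token 256, which is why the merges have to be replayed in the order of their ids.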
        

In [11]:
# ! conda install -y regex

In [12]:
import regex    

regex_pat_str = "|".join(
        [
            r"""[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?""",
            r"""[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?""",
            r"""\p{N}{1,3}""",
            r""" ?[^\s\p{L}\p{N}]+[\r\n/]*""",
            r"""\s*[\r\n]+""",
            r"""\s+(?!\S)""",
            r"""\s+""",
        ]
    )

print(regex.findall(regex_pat_str, "Hello World!"))

['Hello', ' World', '!']
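
Splitting the text first is how the "dog. dog! dog?" problem from the GPT-2 quote is avoided: pairs are only counted inside a chunk, so a letter and the punctuation after it never merge. The sketch below uses the standard `re` module with a much simpler, ASCII-only pattern as a stand-in (an assumption; it is not equivalent to the pattern above), and `chunk_pair_counts` is a hypothetical helper.

```python
import re
from collections import Counter

# Simplified, ASCII-only stand-in for the pattern above (an assumption).
SIMPLE_PAT = re.compile(r" ?[A-Za-z]+| ?[0-9]+| ?[^\sA-Za-z0-9]+|\s+")

def chunk_pair_counts(text):
    """Split the text into chunks, then count byte pairs inside each chunk only."""
    chunks = SIMPLE_PAT.findall(text)
    counts = Counter()
    for chunk in chunks:
        ids = list(chunk.encode("utf-8"))
        counts.update(zip(ids, ids[1:]))
    return chunks, counts

chunks, counts = chunk_pair_counts("dog. dog! dog?")
print(chunks)               # ['dog', '.', ' dog', '!', ' dog', '?']
print(counts[(100, 111)])   # 3  ("d", "o")
print((103, 46) in counts)  # False: "g" followed by "." is never counted
```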


In [13]:
import regex    

# regex pattern for gpt4o
regex_pat_str = "|".join(
        [
            r"""[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?""",
            r"""[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?""",
            r"""\p{N}{1,3}""",
            r""" ?[^\s\p{L}\p{N}]+[\r\n/]*""",
            r"""\s*[\r\n]+""",
            r"""\s+(?!\S)""",
            r"""\s+""",
        ]
    )

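merges = {}  # (left_id, right_id) -> merged token id, shared across all chunks

def train_tokens(text):
    unicode_text = text.encode("utf-8")
    all_tokens = list(map(int,unicode_text))
    max_token_id = max(all_tokens)
    # NOTE: new ids are allocated from max_token_id + 1. When the corpus never uses
    # the highest byte values, these ids overlap the raw byte range (0-255);
    # starting from 255 instead would keep merged ids at 256 and above.
    idx =  max_token_id
    print(f"max_token_id: {max_token_id}")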
    
    regex_match_tokens = regex.findall(regex_pat_str, text)
    print(regex_match_tokens)
    num = 0
    for regex_match_token in regex_match_tokens:
        if num == 30:
            break
        print("===================================================================")
        print(regex_match_token)
        unicode_text = regex_match_token.encode("utf-8")
        print("number of unicode characters: {} characters".format(len(unicode_text)))
        
        tokens = list(map(int,unicode_text))
        print(tokens)
        print("number of tokens: {} tokens".format(len(tokens)))
        
        # if len(tokens) == 1: 
        #     continue
        
        # stats = get_stats(tokens)
                
        # print(stats)
        # sorted_stats = sorted(((v,k) for k,v in stats.items()),reverse=True)
        # print(f"sorted_stats: {sorted_stats}")
        
        # print(f"Most appeared pair of unicodes that appeard are {chr(sorted_stats[0][1][0])} {chr(sorted_stats[0][1][1])}. That appeared {sorted_stats[0][0]} times")

        # for sorted_stat in sorted_stats:
        #     print(f"{chr(sorted_stat[1][0])}{chr(sorted_stat[1][1])}: {sorted_stat[0]} times")
            
            
        num_merges = 100
        ids = list(tokens)
        print(ids)

        for i in range(num_merges):
            stats:dict = get_stats(ids)

            print(f"stats = {stats}")
            if (len(stats) >= 1):
                pair:tuple = max(stats, key=stats.get)
                idx +=1
                sorted_stats = sorted(((v,k) for k,v in stats.items()),reverse=True)
                print(f"sorted_stats: {sorted_stats}")
                print(f"Most appeared pair of unicodes that appeard are {chr(sorted_stats[0][1][0])} {chr(sorted_stats[0][1][1])}. That appeared {sorted_stats[0][0]} times")    
                print(f"merging {pair} into a new token {idx}")
                ids:list = merge(ids, pair, idx)
                print(f"ids:{ids}")
                merges[pair] = idx
            vocab_dict = get_vocab_dict(merges)
            if (idx in vocab_dict):
                print(f"idx: {idx} => {decode([idx], vocab_dict)}")
            if len(stats) == 0:
                break
        
        print("===================================================================")
        num += 1

train_tokens(full_content)

max_token_id: 226
Harry
number of unicode characters: 5 characters
[72, 97, 114, 114, 121]
number of tokens: 5 tokens
[72, 97, 114, 114, 121]
stats = {(72, 97): 1, (97, 114): 1, (114, 114): 1, (114, 121): 1}
sorted_stats: [(1, (114, 121)), (1, (114, 114)), (1, (97, 114)), (1, (72, 97))]
Most appeared pair of unicodes that appeard are r y. That appeared 1 times
merging (72, 97) into a new token 227
ids:[227, 114, 114, 121]
merges: {(72, 97): 227}
idx: 227 => Ha
stats = {(227, 114): 1, (114, 114): 1, (114, 121): 1}
sorted_stats: [(1, (227, 114)), (1, (114, 121)), (1, (114, 114))]
Most appeared pair of unicodes that appeard are ã r. That appeared 1 times
merging (227, 114) into a new token 228
ids:[228, 114, 121]
merges: {(72, 97): 227, (227, 114): 228}
idx: 228 => Har
stats = {(228, 114): 1, (114, 121): 1}
sorted_stats: [(1, (228, 114)), (1, (114, 121))]
Most appeared pair of unicodes that appeard are ä r. That appeared 1 times
merging (228, 114) into a new token 229
ids:[229, 121]
merge

KeyError: 260
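
The loop above merges inside one chunk at a time and numbers new tokens from the largest byte value seen in the corpus. For comparison, here is a minimal sketch of a different arrangement, assuming new ids start at 256 and pair counts are pooled over all chunks before each merge. `train_bpe`, the sample `chunks`, and `num_merges=5` are illustrative and not part of the notebook.

```python
from collections import Counter

def merge(ids, pair, idx):
    newids, i = [], 0
    while i < len(ids):
        if i < len(ids) - 1 and ids[i] == pair[0] and ids[i + 1] == pair[1]:
            newids.append(idx)
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids

def train_bpe(chunks, num_merges):
    ids_per_chunk = [list(c.encode("utf-8")) for c in chunks]
    merges = {}
    vocab = {i: bytes([i]) for i in range(256)}  # ids 0-255 stay raw bytes
    for step in range(num_merges):
        # Pool pair counts over every chunk; pairs never cross chunk borders.
        counts = Counter()
        for ids in ids_per_chunk:
            counts.update(zip(ids, ids[1:]))
        if not counts:
            break
        pair = max(counts, key=counts.get)
        idx = 256 + step  # merged ids never collide with byte values
        ids_per_chunk = [merge(ids, pair, idx) for ids in ids_per_chunk]
        merges[pair] = idx
        vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
    return merges, vocab, ids_per_chunk

def decode(ids, vocab):
    return b"".join(vocab[i] for i in ids).decode("utf-8", errors="replace")

chunks = ["Harry", " Potter", " is", " a", " series", " Harry", " Potter"]
merges, vocab, encoded = train_bpe(chunks, num_merges=5)
print(merges)
print("".join(decode(ids, vocab) for ids in encoded))
```

Because each pair is merged exactly once and every new id is registered in `vocab` immediately, no entry of `merges` is overwritten and decoding cannot hit a missing id.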

In [None]:
encoded_tokens = encode("hello world", merges)
print(encoded_tokens)
encoded_tokens =encode("Harry Potter is a series of seven fantasy novels written by British author J.K. Rowling.", merges)
print(encoded_tokens)

encode len tokens: 11
encode stats: {(104, 101): 1, (101, 108): 1, (108, 108): 1, (108, 111): 1, (111, 32): 1, (32, 119): 1, (119, 111): 1, (111, 114): 1, (114, 108): 1, (108, 100): 1}
encode stats: {(104, 101): 1, (101, 108): 1, (108, 108): 1, (108, 111): 1, (111, 266): 1, (266, 111): 1, (111, 114): 1, (114, 108): 1, (108, 100): 1}
[104, 101, 108, 108, 111, 266, 111, 114, 108, 100]
encode len tokens: 88
encode stats: {(72, 97): 1, (97, 114): 1, (114, 114): 1, (114, 121): 1, (121, 32): 3, (32, 80): 1, (80, 111): 1, (111, 116): 1, (116, 116): 2, (116, 101): 2, (101, 114): 2, (114, 32): 2, (32, 105): 1, (105, 115): 2, (115, 32): 3, (32, 97): 2, (97, 32): 1, (32, 115): 2, (115, 101): 2, (114, 105): 3, (105, 101): 1, (101, 115): 1, (32, 111): 1, (111, 102): 1, (102, 32): 1, (101, 118): 1, (118, 101): 2, (101, 110): 2, (110, 32): 2, (32, 102): 1, (102, 97): 1, (97, 110): 1, (110, 116): 1, (116, 97): 1, (97, 115): 1, (115, 121): 1, (32, 110): 1, (110, 111): 1, (111, 118): 1, (101, 108): 1, (

In [None]:
print(merges)

{(72, 97): 227, (227, 114): 228, (228, 114): 229, (229, 121): 230, (32, 80): 231, (231, 111): 232, (232, 116): 233, (233, 116): 234, (234, 101): 235, (235, 114): 236, (32, 105): 237, (237, 115): 238, (32, 97): 282, (32, 115): 248, (240, 101): 241, (241, 114): 242, (242, 105): 243, (243, 101): 244, (244, 115): 245, (32, 111): 246, (246, 102): 247, (248, 101): 249, (249, 118): 250, (250, 101): 251, (251, 110): 252, (32, 102): 253, (253, 97): 254, (254, 110): 255, (255, 116): 256, (256, 97): 257, (257, 115): 258, (258, 121): 259, (32, 110): 260, (260, 111): 261, (261, 118): 262, (262, 101): 263, (263, 108): 264, (264, 115): 265, (32, 119): 266, (266, 114): 267, (267, 105): 268, (268, 116): 269, (269, 116): 270, (270, 101): 271, (271, 110): 272, (32, 98): 273, (273, 121): 274, (32, 66): 275, (275, 114): 276, (276, 105): 277, (277, 116): 278, (278, 105): 279, (279, 115): 280, (280, 104): 281, (282, 117): 283, (283, 116): 284, (284, 104): 285, (285, 111): 286, (286, 114): 287, (32, 74): 288,

In [None]:
import json

vocab_dict = get_vocab_dict(merges)
print(vocab_dict)

# Iterate over the ids that actually exist; they are not guaranteed to be contiguous
vocab_dict_str = {}
for vocab_id in sorted(vocab_dict):
    print(decode([vocab_id], vocab_dict))
    vocab_dict_str[vocab_id] = decode([vocab_id], vocab_dict)

print(encode(" Python", merges))

# Save the human-readable vocabulary as JSON
vocab_path = "./vocab_dict.json"
with open(vocab_path, mode="wt", encoding="utf-8") as f:
    json.dump(vocab_dict_str, f, ensure_ascii=False, indent=2)

merges: {(72, 97): 227, (227, 114): 228, (228, 114): 229, (229, 121): 230, (32, 80): 231, (231, 111): 232, (232, 116): 233, (233, 116): 234, (234, 101): 235, (235, 114): 236, (32, 105): 237, (237, 115): 238, (32, 97): 282, (32, 115): 248, (240, 101): 241, (241, 114): 242, (242, 105): 243, (243, 101): 244, (244, 115): 245, (32, 111): 246, (246, 102): 247, (248, 101): 249, (249, 118): 250, (250, 101): 251, (251, 110): 252, (32, 102): 253, (253, 97): 254, (254, 110): 255, (255, 116): 256, (256, 97): 257, (257, 115): 258, (258, 121): 259, (32, 110): 260, (260, 111): 261, (261, 118): 262, (262, 101): 263, (263, 108): 264, (264, 115): 265, (32, 119): 266, (266, 114): 267, (267, 105): 268, (268, 116): 269, (269, 116): 270, (270, 101): 271, (271, 110): 272, (32, 98): 273, (273, 121): 274, (32, 66): 275, (275, 114): 276, (276, 105): 277, (277, 116): 278, (278, 105): 279, (279, 115): 280, (280, 104): 281, (282, 117): 283, (283, 116): 284, (284, 104): 285, (285, 111): 286, (286, 114): 287, (32, 7