In [1]:
# Read the whole training corpus into memory as a single string.
with open("input.txt", encoding='utf-8') as corpus_file:
    text = corpus_file.read()

# Character-level vocabulary: every distinct character of the corpus, sorted.
chars = sorted(set(text))
vocab_size = len(chars)
print(f"We have {vocab_size} unique characters in our vocabulary")

We have 65 unique characters in our vocabulary


In [2]:
# UTF-8 encode the corpus and keep it as a plain list of integer byte values;
# this list is the starting token sequence for the merge procedure.
utf_ids = [int(byte) for byte in text.encode('utf-8')]

In [4]:
def get_pair_counts_dict(utf_ids):
    """Count how often each adjacent pair of ids occurs in ``utf_ids``.

    Returns a dict mapping ``(left_id, right_id)`` tuples to their number of
    occurrences. Keys appear in the order each pair is first seen. A sequence
    with fewer than two elements yields an empty dict.
    """
    tally = {}
    for left, right in zip(utf_ids, utf_ids[1:]):
        key = (left, right)
        if key in tally:
            tally[key] += 1
        else:
            tally[key] = 1
    return tally

In [5]:
def merge_and_replace_pairs(utf_ids, pair, new_id):
    """Return a new list in which every occurrence of ``pair`` is replaced by ``new_id``.

    The sequence is scanned left to right and matches do not overlap: once two
    adjacent elements are merged, scanning resumes after the second one. The
    input sequence is not modified.
    """
    merged = []
    total = len(utf_ids)
    cursor = 0
    while cursor < total:
        has_next = cursor + 1 < total
        if has_next and utf_ids[cursor] == pair[0] and utf_ids[cursor + 1] == pair[1]:
            # Consume both halves of the pair and emit the merged id instead.
            merged.append(new_id)
            cursor += 2
            continue
        merged.append(utf_ids[cursor])
        cursor += 1
    return merged

In [6]:
num_merges = 128  # how many merge rules to learn
new_id = 256      # ids 0-255 are the raw byte values, so merged tokens start here

new_utf_ids = utf_ids.copy()  # work on a copy so the raw byte sequence stays available
merges_map = {}               # (left_id, right_id) -> merged id, in the order learned

for i in range(num_merges):
    pair_counts_dict = get_pair_counts_dict(new_utf_ids)
    if not pair_counts_dict:
        # Fewer than two tokens remain, so there is no pair left to merge;
        # max() below would raise ValueError on the empty dict.
        break
    # Most frequent adjacent pair; ties go to the pair seen first in the sequence.
    max_count_pair = max(pair_counts_dict, key=pair_counts_dict.get)
    new_id = 256 + i
    new_utf_ids = merge_and_replace_pairs(new_utf_ids, max_count_pair, new_id)
    merges_map[max_count_pair] = new_id

In [6]:
# Base vocabulary: each of the 256 possible byte values maps to its own single-byte string.
byte_map = dict((byte_value, bytes([byte_value])) for byte_value in range(256))

In [7]:
# Add one byte_map entry per learned merge: the bytes of a merged token are the
# bytes of its two halves concatenated. merges_map is walked in insertion order,
# and both halves must already be present in byte_map when a pair is reached.
for merged_pair, new_id in merges_map.items():
    id_1, id_2 = merged_pair
    byte_map[new_id] = byte_map[id_1] + byte_map[id_2]

In [8]:
def decode_ids_to_string(ids_list):
    """Turn a sequence of token ids back into text.

    Each id is looked up in the module-level ``byte_map``; the resulting byte
    strings are concatenated and decoded as UTF-8. Bytes that do not form valid
    UTF-8 are rendered with the replacement character instead of raising.
    An id that is missing from ``byte_map`` raises ``KeyError``.
    """
    raw = bytearray()
    for token_id in ids_list:
        raw += byte_map[token_id]
    return raw.decode('utf-8', errors='replace')

In [9]:
def encode_string_to_ids(text):
    """Encode a string into token ids using the merges in ``merges_map``.

    Starts from the UTF-8 bytes of ``text``. On each round it looks at the
    adjacent pairs currently present and applies the one whose entry in the
    module-level ``merges_map`` has the smallest id. It stops when no present
    pair has a merge rule or fewer than two tokens remain.

    Returns the resulting list of integer token ids.
    """
    def merge_rank(candidate):
        # Pairs without a merge rule rank last.
        return merges_map.get(candidate, float('inf'))

    token_ids = list(text.encode('utf-8'))
    while len(token_ids) >= 2:
        present_pairs = get_pair_counts_dict(token_ids)
        best_pair = min(present_pairs, key=merge_rank)
        if best_pair in merges_map:
            token_ids = merge_and_replace_pairs(token_ids, best_pair, merges_map[best_pair])
        else:
            # Even the best-ranked pair has no rule, so nothing more can be merged.
            break
    return token_ids

In [24]:
# Encode a short sample word with the learned merges; x holds the resulting token ids.
x = encode_string_to_ids("lower")

In [26]:
# Show, for each token id in x, the text that its bytes spell out.
[byte_map[token_id].decode('utf-8') for token_id in x]

['l', 'ow', 'er']

In [11]:
import regex as re

# Pre-tokenization split pattern (presumably the one used by GPT-2, going by the name — confirm).
# Alternatives, tried left to right: the contractions 's 't 're 've 'm 'll 'd; a run of
# letters; a run of numbers; a run of characters that are neither whitespace, letters nor
# numbers (each of these three optionally preceded by one space); whitespace that is not
# followed by a non-whitespace character; any other whitespace run.
# The \p{L} / \p{N} Unicode property classes need the third-party `regex` module, which the
# import above binds to the name `re`; the standard-library `re` does not support them.
gpt2regex = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

In [27]:
import os
import json

# Load the parsed contents of encoder.json.
# NOTE(review): presumably a token-string -> id mapping, going by the file name — confirm.
# The encoding is given explicitly so the read does not depend on the platform's
# locale default (and matches the other open() calls in this notebook).
with open('encoder.json', 'r', encoding='utf-8') as file:
    gpt2_byte_map = json.load(file)

with open('vocab.bpe', 'r', encoding='utf-8') as file:
    bpe_data = file.read()

# One tuple of whitespace-separated tokens per merge line. The first line is skipped
# (presumably a header — confirm). Blank lines are filtered out instead of slicing off
# the last element, so the final rule is kept even if the file has no trailing newline.
gpt2_merges_map = [
    tuple(merge_str.split())
    for merge_str in bpe_data.split('\n')[1:]
    if merge_str.strip()
]