In [2]:
#turning the training text into a list of byte values (ints 0-255) taken from its utf-8 encoding
text = "hello world its me will and I'm a cool guy with lots to say, just kidding im actually very humble. The reason i said that is because i am trying to create a really long string with lots of words so i can train this tokeniser. pretty genius way of creating text dont you think? i'm literally just wrtiting out my inner dialogue. Ok im going to stop now because im feeling kind of crazy"
#encode gives raw bytes; iterating over bytes yields one int per byte
tokens = [int(byte) for byte in text.encode("utf-8")]
print("# tokens: ", len(tokens))

b"hello world its me will and I'm a cool guy with lots to say, just kidding im actually very humble. The reason i said that is because i am trying to create a really long string with lots of words so i can train this tokeniser. pretty genius way of creating text dont you think? i'm literally just wrtiting out my inner dialogue. Ok im going to stop now because im feeling kind of crazy"
# tokens:  384


In [35]:
#counting how often each pair of neighbouring ids occurs
def get_stats(ids):
    """Tally every adjacent pair in ``ids``.

    Args:
        ids: indexable sequence of token ids (e.g. a list of ints).

    Returns:
        dict mapping each (ids[k], ids[k+1]) tuple to the number of positions k
        where it occurs. Keys are inserted in order of first appearance.
        A sequence with fewer than two ids gives an empty dict.
    """
    pair_counts = {}
    for k in range(len(ids) - 1):
        key = (ids[k], ids[k + 1])
        if key in pair_counts:
            pair_counts[key] += 1
        else:
            pair_counts[key] = 1 #first time this pair is seen
    return pair_counts

#counting paired tokens: stats maps each adjacent (id, id) pair in tokens -> number of times it occurs
stats = get_stats(tokens)
#uncomment to list the pairs as (count, pair) tuples, most frequent first
# print( sorted( ((v,k) for k,v in stats.items()) , reverse=True) )

In [3]:
#in a list of ids (ints), swap every consecutive occurrence of pair (two ids) for the new token idx (int)
def merge(ids, pair, idx):
    """Return a new list where each left-to-right, non-overlapping match of ``pair`` is replaced by ``idx``.

    Args:
        ids: sequence of token ids to scan; it is not modified.
        pair: two-element sequence (first id, second id) to look for.
        idx: token id written in place of each matched pair.

    Returns:
        A new list of ids. Matching is greedy from the left, so [1, 1, 1] with
        pair (1, 1) becomes [idx, 1].
    """
    merged = []
    last = len(ids) - 1
    pos = 0
    while pos <= last:
        #a pair can only start here if there is still an element to its right
        if pos < last and ids[pos] == pair[0] and ids[pos + 1] == pair[1]:
            merged.append(idx)
            pos += 2 #skip both halves of the matched pair
            continue
        merged.append(ids[pos])
        pos += 1
    return merged

print( merge([5, 6, 6, 7, 9, 1], (6, 7), 99) )

[5, 6, 99, 9, 1]


In [37]:
#training: repeatedly merge the most frequent adjacent pair into a brand new token id
vocab_size = 256 + 100 #256 raw byte values + up to 100 merged tokens
num_merges = vocab_size - 256
ids = list(tokens) #creating copy so the original tokens list stays untouched

merges = {} #(id, id) pair -> new token id, stored in the order the merges were learned
#for num_merges, find most common pair, create new idx, update the output (ids) and store the merge in dict (merges)
for i in range(num_merges):
    #find most common pair
    stats = get_stats(ids)
    #fewer than two ids left -> no pairs at all; max() on an empty dict would raise ValueError, so stop early
    if not stats:
        break
    pair = max(stats, key=stats.get) #getting the current top pair (on a tie max keeps the first one it meets)
    #create new idx, update the output (ids)
    idx = 256 + i #new ids start right after the 256 byte values
    ids = merge(ids, pair, idx)
    #store curr merge in dict (merges)
    merges[pair] = idx

# print(f"before: {len(tokens)} after: {len(ids)} ratio: {len(tokens)/len(ids):.2f}")

In [38]:
#lookup table: token id (int) -> bytes object that the token stands for
#ids 0..255 each map to the single byte with that value
vocab = {byte_val: bytes([byte_val]) for byte_val in range(256)}
#merged tokens: the bytes of both halves joined together
#dicts iterate in insertion order, so the halves of each merge are looked up in the order the merges were stored
for (p0, p1), idx in merges.items():
    vocab[idx] = b"".join((vocab[p0], vocab[p1]))

# "training complete" -> merges performed and vocab created (partially using merges)

In [39]:
#given tokens (list of ints), return python string
def decode(ids):
    """Turn token ids back into text.

    Each id is looked up in the module-level ``vocab`` table, the resulting byte
    strings are concatenated, and the whole buffer is decoded as utf-8.
    Byte sequences that are not valid utf-8 are decoded with errors="replace"
    rather than raising.

    Args:
        ids: iterable of token ids present in ``vocab``.

    Returns:
        The decoded str.

    Raises:
        KeyError: if an id is not in ``vocab``.
    """
    raw = bytearray()
    for token_id in ids:
        raw += vocab[token_id]
    return bytes(raw).decode("utf-8", errors="replace")

In [40]:
#given string, return list of ints (tokens)
def encode(text):
    """Tokenise ``text`` using the merges stored in the module-level ``merges`` dict.

    The text is first split into its utf-8 byte values. Then, while at least two
    tokens remain, the adjacent pair with the smallest merge id is replaced
    everywhere by that id. The loop ends as soon as no adjacent pair is a key of
    ``merges``.

    Args:
        text: str to tokenise.

    Returns:
        list of int token ids.
    """
    out = list(text.encode("utf-8")) #raw bytes as ints
    while len(out) >= 2:
        #only pairs that are keys of merges can be replaced
        learned = [p for p in get_stats(out) if p in merges]
        if not learned:
            break #nothing left that can be merged
        #apply the pair with the lowest merge id first
        earliest = min(learned, key=lambda p: merges[p])
        out = merge(out, earliest, merges[earliest])
    return out

In [44]:
#round trip on a short string: prints the text recovered after encoding then decoding
print(decode(encode("hello world")))
#round trip on the full training text: prints True when decode(encode(text)) gives back text unchanged
print(text == decode(encode(text)))

hello world
True


In [29]:
# FINISHED BASIC IMPLEMENTATION OF BYTE PAIR TOKENISER