In [1]:
### CONGREGATING DATA ###
import os
directory = "mini_LDR"

# Gather the text of every .ldr file in `directory` into one training string.
# os.listdir() returns entries in arbitrary order, so the names are sorted to
# make the concatenated corpus (and therefore the merges learned from it)
# identical from run to run.
ldr_contents = []
for filename in sorted(os.listdir(directory)):
    if filename.endswith(".ldr"):
        file_path = os.path.join(directory, filename)
        with open(file_path, 'r') as file:
            ldr_contents.append(file.read())

# One join at the end instead of `+=` in the loop, which re-copies the
# ever-growing string on each file.
all_files = "".join(ldr_contents)

In [2]:
### PRE-TOKENISATION USING GPT4 ###
import regex as regx
#GPT-4 split pattern, taken from tiktoken:
#https://github.com/openai/tiktoken/blob/main/tiktoken_ext/openai_public.py
gpt4_pattern = regx.compile(
    r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
)

#chop the whole corpus into the list of pieces matched by the pattern
all_sub_units = gpt4_pattern.findall(all_files)

In [9]:
### CONVERT EACH CHAR TO UTF-8 RAW BYTES AND INITIALISE TOKEN SET ###
#each pre-token string -> its UTF-8 bytes -> list of ints (one int per byte)
subUnits = [list(piece.encode("utf-8")) for piece in all_sub_units]

In [10]:
### FUNCS: MOST FREQUENT PAIR & MERGE  ###

#finding the frequency of adjacent pairs within each subUnit
def get_counts(ids):
    """Count adjacent token pairs across a list of token sequences.

    Pairs are only formed inside a single sequence, never across two of them.

    Args:
        ids: list of sequences (e.g. lists of ints).

    Returns:
        dict mapping (left, right) tuples to how many times that pair occurs.
        Keys appear in first-seen order.
    """
    pair_freq = {}
    for seq in ids:
        for pos in range(len(seq) - 1):
            key = (seq[pos], seq[pos + 1])
            if key in pair_freq:
                pair_freq[key] += 1
            else:
                pair_freq[key] = 1
    return pair_freq

#replacing all occurrences of pair within ids (token_subUnits) with the new token idx
def merge_tokens(ids, pair, idx):
    """Replace every occurrence of `pair` with `idx` in each sequence of `ids`.

    Each sequence is scanned left to right and matches do not overlap: once a
    pair is replaced, scanning resumes after its second element.

    Args:
        ids: list of token sequences.
        pair: two-element sequence (left, right) to look for.
        idx: token id written in place of each matched pair.

    Returns:
        A new list of new lists; the input is left unmodified.
    """
    merged = []
    for seq in ids:
        out = []
        skip_next = False
        for pos, tok in enumerate(seq):
            if skip_next:
                # this element was the right half of a pair just merged
                skip_next = False
                continue
            if pos + 1 < len(seq) and tok == pair[0] and seq[pos + 1] == pair[1]:
                out.append(idx)
                skip_next = True
            else:
                out.append(tok)
        merged.append(out)
    return merged

In [11]:
### TRAINING THE TOKENIZER ###

#setting hyper param
vocab_size = 500
num_merges = vocab_size - 256   #ids 0-255 are taken by the raw byte values
allSubUnits = subUnits

#PERFORMING MERGES -> find most common pair, then merge it
merges = {}
for i in range(num_merges):
    frequencies = get_counts(allSubUnits)
    #every sub-unit is down to a single token, so there is nothing left to
    #merge; without this guard max() raises ValueError on the empty dict
    if not frequencies:
        print(f"Stopped after {i} merges: no adjacent pairs left")
        break
    topPair = max(frequencies, key=frequencies.get) #compares on values rather than keys
    newID = 256 + i
    allSubUnits = merge_tokens(allSubUnits, topPair, newID)
    merges[topPair] = newID
    #print statement to show progress
    if (i + 1) % 100 == 0:
        print(f"Processed {i + 1} merges")

Processed 100 merges
Processed 200 merges


In [12]:
#### FINALISING VOCABULARY ###
#base vocabulary: ids 0-255 each map to the matching single byte
vocab = {byte_val: bytes([byte_val]) for byte_val in range(256)}
#merged tokens: dicts iterate in insertion order, so both halves of a pair
#are already in vocab by the time that pair is reached
for pair, idx in merges.items():
    vocab[idx] = vocab[pair[0]] + vocab[pair[1]] #concatenating 'bytes' objects

### TRAINING COMPLETE ###

In [17]:
### INFERENCE CODE -> USED IN ENCODE & DECODE ###
def get_counts_inf(ids):
    """Count adjacent pairs in one flat token sequence.

    Args:
        ids: a single sequence of tokens.

    Returns:
        dict mapping (left, right) tuples to their number of occurrences,
        with keys in first-seen order.
    """
    tally = {}
    for pos in range(len(ids) - 1):
        key = (ids[pos], ids[pos + 1])
        tally[key] = tally[key] + 1 if key in tally else 1
    return tally

def merge_tokens_inf(ids, pair, idx):
    """Replace each non-overlapping occurrence of `pair` in `ids` with `idx`.

    The sequence is scanned left to right; after a match, scanning resumes
    past the pair's second element.

    Args:
        ids: a single sequence of tokens.
        pair: two-element sequence (left, right) to look for.
        idx: token id written in place of each matched pair.

    Returns:
        A new list; `ids` itself is not modified.
    """
    out = []
    skip_next = False
    for pos, tok in enumerate(ids):
        if skip_next:
            # right half of the pair that was just merged
            skip_next = False
            continue
        if pos + 1 < len(ids) and tok == pair[0] and ids[pos + 1] == pair[1]:
            out.append(idx)
            skip_next = True
        else:
            out.append(tok)
    return out

In [20]:
### ENCODE & DECODE FUNCTIONS LDR <-> Tokens ###
def encode(text):
    """Encode a string into a list of token ids using the learned merges.

    The text is first split with `gpt4_pattern`, the same pattern used to
    build the training sub-units. Training only ever counted and merged pairs
    inside a sub-unit, so running BPE per chunk keeps inference consistent
    with it: no merge can fire across a chunk boundary.

    Args:
        text: the string to tokenise.

    Returns:
        list of int token ids (raw byte values and/or ids from `merges`).
    """
    tokens = []
    for chunk in gpt4_pattern.findall(text):
        chunk_ids = list(chunk.encode("utf-8")) #raw bytes
        #loop till no more merges possible within this chunk
        while len(chunk_ids) >= 2:
            frequencies = get_counts_inf(chunk_ids)
            #pick the pair with the lowest id in merges (earliest learned is
            #applied first); pairs not in merges rank as infinity
            pair = min(frequencies, key=lambda p: merges.get(p, float("inf")))
            if pair not in merges:
                break
            chunk_ids = merge_tokens_inf(chunk_ids, pair, merges[pair])
        tokens.extend(chunk_ids)
    return tokens

def decode(tokens):
    """Turn a sequence of token ids back into text.

    Each id is looked up in `vocab` to get its bytes; the bytes are joined
    and decoded as UTF-8. Byte sequences that are not valid UTF-8 become the
    replacement character rather than raising.

    Args:
        tokens: iterable of token ids present in `vocab`.

    Returns:
        The decoded string.
    """
    pieces = []
    for token in tokens:
        pieces.append(vocab[token])
    raw = b"".join(pieces)
    return raw.decode("utf-8", errors="replace")

success: True


In [25]:
### TESTING ###
#total number of tokens across all sub-units, before vs after the merges
before = sum(len(unit) for unit in subUnits)
after = sum(len(unit) for unit in allSubUnits)
print(f"Count before: {before}\nCount after: {after}")

#first few sub-units as raw bytes and as merged tokens
print(f"\nBEFORE: {subUnits[:10]}")
print(f"\nAFTER: {allSubUnits[:10]}")

#round-trip a sample line through encode/decode
text = "0 !LEOCAD MODEL AUTHOR LEGO staff (unknown);"
theTokens = encode(text)
result = decode(theTokens)
print(f"\nEncode & Decode Success: {result == text}")

Count before: 25246
Count after: 17928

BEFORE: [[48], [32, 33], [76, 69, 79, 67, 65, 68], [32, 77, 79, 68, 69, 76], [32, 65, 85, 84, 72, 79, 82], [32, 76, 69, 71, 79], [32, 115, 116, 97, 102, 102], [32, 40], [117, 110, 107, 110, 111, 119, 110], [41, 59]]

AFTER: [[48], [290], [294], [325], [426], [455], [458], [382], [462], [384]]

Encode & Decode Success: True
