In [95]:
### CONGREGATING DATA ###
import os
directory = "mini_LDR"

# Build the training corpus: the contents of every .ldr file in `directory`,
# concatenated into the single string `all_files`.
#  - os.listdir() returns entries in arbitrary order, so the names are sorted
#    to make the corpus (and therefore the learned merges) reproducible.
#  - the encoding is pinned to UTF-8 instead of the platform default, because
#    the text is re-encoded as UTF-8 bytes before training.
#  - pieces are collected in a list and joined once rather than growing a
#    string with += for every file.
ldr_contents = []
for filename in sorted(os.listdir(directory)):
    if filename.endswith(".ldr"):
        file_path = os.path.join(directory, filename)
        with open(file_path, 'r', encoding="utf-8") as file:
            ldr_contents.append(file.read())
all_files = "".join(ldr_contents)

In [96]:
### PRE-TOKENISATION USING GPT4 ###
import regex as regx
# Split pattern copied from tiktoken's public encodings:
# https://github.com/openai/tiktoken/blob/main/tiktoken_ext/openai_public.py
# It relies on \p{..} classes and possessive quantifiers, which is why the
# third-party `regex` module is used here.
gpt4_split_regex = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
gpt4_pattern = regx.compile(gpt4_split_regex)

# Cut the whole corpus into sub-units (one string per regex match); byte-pair
# merges are later counted inside each sub-unit only.
all_sub_units = gpt4_pattern.findall(all_files)

In [97]:
### CONVERT EACH CHAR TO UTF-8 RAW BYTES AND INITIALISE TOKEN SET ###
#each sub-unit string -> list of the int values of its UTF-8 bytes
#(iterating a bytes object already yields ints, so list() is enough)
subUnits = [list(subUnit.encode("utf-8")) for subUnit in all_sub_units]

#quick look at the first few converted sub-units
print(subUnits[:10])

[[48], [32, 33], [76, 69, 79, 67, 65, 68], [32, 77, 79, 68, 69, 76], [32, 65, 85, 84, 72, 79, 82], [32, 76, 69, 71, 79], [32, 115, 116, 97, 102, 102], [32, 40], [117, 110, 107, 110, 111, 119, 110], [41, 59]]


In [98]:
### FUNCS: MOST FREQUENT PAIR & MERGE  ###

#finding the frequency of adjacent pairs within each subUnit
def get_counts(ids):
    """Count adjacent token pairs inside each sequence of `ids`.

    ids: iterable of sequences (e.g. a list of lists of ints). Pairs are only
         formed between neighbours within one sequence, never across two
         sequences.
    Returns a dict mapping (left, right) tuples to their number of
    occurrences, with keys in first-seen order.
    """
    pair_freq = {}
    for seq in ids:
        for pos in range(len(seq) - 1):
            key = (seq[pos], seq[pos + 1])
            if key in pair_freq:
                pair_freq[key] += 1
            else:
                pair_freq[key] = 1  #first time this pair is seen
    return pair_freq

#replacing all occurrences of pair, within ids (token_subUnits), with new token idx
def merge_tokens(ids, pair, idx):
    """Return a copy of `ids` where each occurrence of `pair` is replaced by `idx`.

    ids:  iterable of token sequences (list of lists of ints).
    pair: the two adjacent tokens to look for, as (left, right).
    idx:  the token written in place of each matched pair.

    Every sequence is scanned left to right and matches do not overlap, so
    [a, a, a] merged on (a, a) becomes [idx, a]. The input is not modified.
    """
    left, right = pair[0], pair[1]
    merged_ids = []
    for seq in ids:
        out = []
        last = len(seq) - 1
        pos = 0
        while pos <= last:
            is_match = pos < last and seq[pos] == left and seq[pos + 1] == right
            if is_match:
                out.append(idx)
                pos += 2  #skip both members of the merged pair
            else:
                out.append(seq[pos])
                pos += 1
        merged_ids.append(out)
    return merged_ids

In [99]:
### TRAINING THE TOKENIZER ###

#setting hyper param
vocab_size = 500
num_merges = vocab_size - 256   #ids 0-255 are taken by the raw byte values
allSubUnits = subUnits          #merge_tokens returns new lists, so subUnits itself is never modified

#PERFORMING MERGES -> find most common pair, then merge it
merges = {}                     #(left, right) pair -> new token id, in the order the merges were learned
for i in range(num_merges):
    frequencies = get_counts(allSubUnits)
    #once every sub-unit has shrunk to a single token there are no pairs left;
    #max() on an empty dict would raise ValueError, so stop training early
    if not frequencies:
        print(f"Stopped after {i} merges: no adjacent pairs left to merge")
        break
    topPair = max(frequencies, key=frequencies.get) #compares on vals rather than keys
    newID = 256 + i
    allSubUnits = merge_tokens(allSubUnits, topPair, newID)
    merges[topPair] = newID
    #print statement to show progress
    if (i + 1) % 100 == 0:
        print(f"Processed {i + 1} merges")

Processed 100 merges
Processed 200 merges


In [100]:
#FINALISING VOCABULARY -> using merges, add all new tokens to final vocab
#base vocabulary: every single byte value maps to its own one-byte string
vocab = {byte_val: bytes([byte_val]) for byte_val in range(256)} #bytes needs list arg

#dicts iterate in insertion order, so both halves of a merged pair are
#already present in vocab by the time the pair itself is reached
for pair, new_id in merges.items():
    left, right = pair
    vocab[new_id] = vocab[left] + vocab[right] #concatenating 'byte objects'

###TRAINING COMPLETE###

In [None]:
### ENCODE & DECODE FUNCTIONS LDR <-> Tokens ###

def encode(text):
    """Encode `text` into a list of token ids using the learned `merges`.

    The text is first split with `gpt4_pattern`, the same way the training
    corpus was, so a merge is never applied across a sub-unit boundary.
    Each sub-unit starts as its raw UTF-8 byte values and is then merged
    repeatedly, always applying the earliest-learned applicable merge first.

    text: the string to encode.
    Returns a flat list of ints (byte values 0-255 and/or merged token ids);
    an empty string gives an empty list.
    """
    ids = []
    for chunk in gpt4_pattern.findall(text):
        tokens = list(chunk.encode("utf-8")) #raw bytes

        #loop repeats until no more merges can be performed (the break below exits)
        while len(tokens) >= 2:
            #get_counts/merge_tokens work on a list of sequences, so wrap the single chunk
            stats = get_counts([tokens])
            #want the pair (key) inside stats that has the lowest id in merges (since low level stuff merged first)
            pair = min(stats, key=lambda p: merges.get(p, float("inf")))
            #if pair DNE in merges nothing else can be merged
            if pair not in merges:
                break
            #otherwise update tokens
            tokens = merge_tokens([tokens], pair, merges[pair])[0]
        ids.extend(tokens)
    return ids

def decode(ids):
    """Turn a sequence of token ids back into text.

    Each id is looked up in `vocab` (a KeyError is raised for an unknown id),
    the resulting byte strings are concatenated, and the result is decoded as
    UTF-8. Invalid byte sequences are replaced with U+FFFD rather than
    raising.
    """
    raw = bytearray()
    for token_id in ids:
        raw += vocab[token_id]
    return raw.decode("utf-8", errors="replace")

In [87]:
### TESTING ###
#total number of tokens over all sub-units, before and after the merges
before = sum(len(subUnit) for subUnit in subUnits)
after = sum(len(subUnit) for subUnit in allSubUnits)
print(f"Token count before: {before}\nToken count after: {after}")

#side-by-side look at the first few sub-units: raw bytes vs merged tokens
print("\n\nFirst subUnits before and afterwards\n")
print(subUnits[:10])
print("\n")
print(allSubUnits[:10])

Token count before: 301976
Token count after: 210557


First subUnits before and afterwards

[[48], [32, 33], [76, 69, 79, 67, 65, 68], [32, 77, 79, 68, 69, 76], [32, 65, 85, 84, 72, 79, 82], [32, 76, 69, 71, 79], [32, 115, 116, 97, 102, 102], [32, 40], [117, 110, 107, 110, 111, 119, 110], [41, 59]]


[[48], [303], [307], [370], [509], [613], [595], [572], [636], [584]]
