Byte Pair Encoding

This notebook walks through Byte Pair Encoding (BPE) step by step. The accompanying code generates a JSON file containing the byte pair encodings learned from the given text.

I suggest watching Andrej Karpathy's video on YouTube (https://youtu.be/zduSFxRajkE); it helped a lot in developing this code.

In [11]:
# text from wikipedia: https://en.wikipedia.org/wiki/Text_processing
text = """An editor essentially invokes an input stream and directs it to the text processing environment, which is either a command shell or a text editor. The resulting output is applicable to further text processing, the final result of which is comparable to a single application of an algorithm applied once by a more sophisticated and structured computer program. """
vocab = sorted(set(text))  # unique characters in the text, sorted

print(text)
print("length:", len(text))
print("-----")
print("vocabulary:", "".join(vocab))
print("length:", len(vocab))

An editor essentially invokes an input stream and directs it to the text processing environment, which is either a command shell or a text editor. The resulting output is applicable to further text processing, the final result of which is comparable to a single application of an algorithm applied once by a more sophisticated and structured computer program. 
length: 360
-----
vocabulary:  ,.ATabcdefghiklmnoprstuvwxy
length: 28


In [12]:
token_dict = {}
token_counter = 0

for char in vocab:
    token_dict[token_counter] = char
    token_counter += 1

print(token_dict)

{0: ' ', 1: ',', 2: '.', 3: 'A', 4: 'T', 5: 'a', 6: 'b', 7: 'c', 8: 'd', 9: 'e', 10: 'f', 11: 'g', 12: 'h', 13: 'i', 14: 'k', 15: 'l', 16: 'm', 17: 'n', 18: 'o', 19: 'p', 20: 'r', 21: 's', 22: 't', 23: 'u', 24: 'v', 25: 'w', 26: 'x', 27: 'y'}
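The counter-based loop above builds the same mapping as dict(enumerate(vocab)). The sketch below, on a shorter sample sentence chosen for illustration, also builds the inverse mapping (character to id). char_to_id is a name introduced here, not one from the notebook; with it, encoding becomes a dictionary lookup instead of a scan over every entry of token_dict.

```python
text = "an input stream and a text editor"  # hypothetical shorter sample
vocab = sorted(set(text))

# id -> character, same result as the counter-based loop above
token_dict = dict(enumerate(vocab))

# character -> id, the inverse mapping used for fast encoding
char_to_id = {char: token_id for token_id, char in token_dict.items()}

print(token_dict)
print(char_to_id["a"])  # 1, since ' ' sorts before 'a'
```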


In [13]:
def encode(text):
    # inverse lookup: character -> token id (raises KeyError for characters outside the vocabulary)
    char_to_id = {token: token_id for token_id, token in token_dict.items()}
    return [char_to_id[char] for char in text]

def decode(encoded):
    # concatenate the string of each token id
    return "".join(token_dict[code] for code in encoded)

tokens = encode(text)
print(tokens)
print("length:", len(tokens))

[3, 17, 0, 9, 8, 13, 22, 18, 20, 0, 9, 21, 21, 9, 17, 22, 13, 5, 15, 15, 27, 0, 13, 17, 24, 18, 14, 9, 21, 0, 5, 17, 0, 13, 17, 19, 23, 22, 0, 21, 22, 20, 9, 5, 16, 0, 5, 17, 8, 0, 8, 13, 20, 9, 7, 22, 21, 0, 13, 22, 0, 22, 18, 0, 22, 12, 9, 0, 22, 9, 26, 22, 0, 19, 20, 18, 7, 9, 21, 21, 13, 17, 11, 0, 9, 17, 24, 13, 20, 18, 17, 16, 9, 17, 22, 1, 0, 25, 12, 13, 7, 12, 0, 13, 21, 0, 9, 13, 22, 12, 9, 20, 0, 5, 0, 7, 18, 16, 16, 5, 17, 8, 0, 21, 12, 9, 15, 15, 0, 18, 20, 0, 5, 0, 22, 9, 26, 22, 0, 9, 8, 13, 22, 18, 20, 2, 0, 4, 12, 9, 0, 20, 9, 21, 23, 15, 22, 13, 17, 11, 0, 18, 23, 22, 19, 23, 22, 0, 13, 21, 0, 5, 19, 19, 15, 13, 7, 5, 6, 15, 9, 0, 22, 18, 0, 10, 23, 20, 22, 12, 9, 20, 0, 22, 9, 26, 22, 0, 19, 20, 18, 7, 9, 21, 21, 13, 17, 11, 1, 0, 22, 12, 9, 0, 10, 13, 17, 5, 15, 0, 20, 9, 21, 23, 15, 22, 0, 18, 10, 0, 25, 12, 13, 7, 12, 0, 13, 21, 0, 7, 18, 16, 19, 5, 20, 5, 6, 15, 9, 0, 22, 18, 0, 5, 0, 21, 13, 17, 11, 15, 9, 0, 5, 19, 19, 15, 13, 7, 5, 22, 13, 18, 17, 0, 18, 10, 0,

In [14]:
message = decode(tokens)
print(message)

An editor essentially invokes an input stream and directs it to the text processing environment, which is either a command shell or a text editor. The resulting output is applicable to further text processing, the final result of which is comparable to a single application of an algorithm applied once by a more sophisticated and structured computer program. 
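Decoding the encoded tokens gives back the original text exactly. At the character level this round trip holds for any string built from vocabulary characters, which a quick randomized check can confirm. The sketch below assumes a small sample text and a fixed random seed; a character that never appeared in the text (for example 'z') simply has no id.

```python
import random

text = "a text editor"  # hypothetical sample
token_dict = dict(enumerate(sorted(set(text))))
char_to_id = {char: token_id for token_id, char in token_dict.items()}

def encode(s):
    return [char_to_id[char] for char in s]

def decode(ids):
    return "".join(token_dict[token_id] for token_id in ids)

# property check: decode(encode(s)) == s for random strings over the vocabulary
rng = random.Random(0)  # fixed seed so the check is reproducible
alphabet = list(token_dict.values())
for _ in range(100):
    sample = "".join(rng.choices(alphabet, k=rng.randint(0, 20)))
    assert decode(encode(sample)) == sample

# characters that never appeared in the text have no id
print("z" in char_to_id)  # False
```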


In [15]:
vocab_size = len(token_dict)
print("Current vocab size:", vocab_size)
print("token_id range (%d, %d)"%(0, vocab_size - 1))

Current vocab size: 28
token_id range (0, 27)


In [16]:
def get_frequencies(tokens):
    """Count how often each adjacent pair of tokens occurs."""
    frequencies = {}

    for i in range(len(tokens) - 1):
        pair = (tokens[i], tokens[i + 1])
        frequencies[pair] = frequencies.get(pair, 0) + 1

    return frequencies


frequencies = get_frequencies(tokens)
print(frequencies)
print("-----")

highest_pair = max(frequencies, key=frequencies.get)
print("Highest pair", highest_pair, "-", frequencies[highest_pair], "occurrences")

{(3, 17): 1, (17, 0): 4, (0, 9): 5, (9, 8): 5, (8, 13): 3, (13, 22): 5, (22, 18): 5, (18, 20): 5, (20, 0): 5, (9, 21): 6, (21, 21): 3, (21, 9): 1, (9, 17): 3, (17, 22): 2, (22, 13): 4, (13, 5): 1, (5, 15): 3, (15, 15): 2, (15, 27): 1, (27, 0): 2, (0, 13): 6, (13, 17): 7, (17, 24): 2, (24, 18): 1, (18, 14): 1, (14, 9): 1, (21, 0): 5, (0, 5): 12, (5, 17): 5, (17, 19): 1, (19, 23): 3, (23, 22): 4, (22, 0): 7, (0, 21): 5, (21, 22): 3, (22, 20): 2, (20, 9): 6, (9, 5): 1, (5, 16): 2, (16, 0): 2, (17, 8): 3, (8, 0): 6, (0, 8): 1, (13, 20): 2, (9, 7): 1, (7, 22): 2, (22, 21): 1, (0, 22): 8, (18, 0): 3, (22, 12): 5, (12, 9): 6, (9, 0): 8, (22, 9): 5, (9, 26): 3, (26, 22): 3, (0, 19): 3, (19, 20): 3, (20, 18): 4, (18, 7): 2, (7, 9): 3, (21, 13): 3, (17, 11): 4, (11, 0): 2, (24, 13): 1, (18, 17): 3, (17, 16): 1, (16, 9): 1, (22, 1): 1, (1, 0): 2, (0, 25): 2, (25, 12): 2, (12, 13): 3, (13, 7): 5, (7, 12): 2, (12, 0): 2, (13, 21): 4, (9, 13): 1, (9, 20): 3, (5, 0): 4, (0, 7): 3, (7, 18): 3, (18, 16
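The same counts can also be computed with collections.Counter over zip(tokens, tokens[1:]), which pairs every token with its successor. A small sketch on a toy sequence (the ids are arbitrary). When several pairs tie, max returns the first one it encounters, which here is the pair whose first occurrence comes earliest in the sequence.

```python
from collections import Counter

def get_frequencies(tokens):
    # zip(tokens, tokens[1:]) yields (tokens[i], tokens[i + 1]) for every position i
    return Counter(zip(tokens, tokens[1:]))

tokens = [5, 1, 5, 1, 2, 5, 1, 2]  # toy sequence of token ids
frequencies = get_frequencies(tokens)
highest_pair = max(frequencies, key=frequencies.get)

print(frequencies)
print(highest_pair, "-", frequencies[highest_pair], "occurrences")  # (5, 1) - 3 occurrences
```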

In [17]:
def merge_pairs(tokens, pair, token_id):
    """Replace every non-overlapping occurrence of pair in tokens with token_id."""
    i = 0
    new_tokens = []

    while i < len(tokens):
        # merge only if there is a next token and the two tokens match the pair
        if i < len(tokens) - 1 and tokens[i] == pair[0] and tokens[i + 1] == pair[1]:
            new_tokens.append(token_id)
            i += 2  # skip both tokens of the merged pair
        else:
            new_tokens.append(tokens[i])
            i += 1

    return new_tokens


tokens = merge_pairs(tokens, highest_pair, token_counter)
token_dict[token_counter] = decode(highest_pair) # add new token to the dictionary
token_counter += 1

print(tokens)
print("length:", len(tokens))

print("Notice that the total number of tokens decreased,")
print("while the vocabulary size increased by one")

[3, 17, 0, 9, 8, 13, 22, 18, 20, 0, 9, 21, 21, 9, 17, 22, 13, 5, 15, 15, 27, 0, 13, 17, 24, 18, 14, 9, 21, 28, 17, 0, 13, 17, 19, 23, 22, 0, 21, 22, 20, 9, 5, 16, 28, 17, 8, 0, 8, 13, 20, 9, 7, 22, 21, 0, 13, 22, 0, 22, 18, 0, 22, 12, 9, 0, 22, 9, 26, 22, 0, 19, 20, 18, 7, 9, 21, 21, 13, 17, 11, 0, 9, 17, 24, 13, 20, 18, 17, 16, 9, 17, 22, 1, 0, 25, 12, 13, 7, 12, 0, 13, 21, 0, 9, 13, 22, 12, 9, 20, 28, 0, 7, 18, 16, 16, 5, 17, 8, 0, 21, 12, 9, 15, 15, 0, 18, 20, 28, 0, 22, 9, 26, 22, 0, 9, 8, 13, 22, 18, 20, 2, 0, 4, 12, 9, 0, 20, 9, 21, 23, 15, 22, 13, 17, 11, 0, 18, 23, 22, 19, 23, 22, 0, 13, 21, 28, 19, 19, 15, 13, 7, 5, 6, 15, 9, 0, 22, 18, 0, 10, 23, 20, 22, 12, 9, 20, 0, 22, 9, 26, 22, 0, 19, 20, 18, 7, 9, 21, 21, 13, 17, 11, 1, 0, 22, 12, 9, 0, 10, 13, 17, 5, 15, 0, 20, 9, 21, 23, 15, 22, 0, 18, 10, 0, 25, 12, 13, 7, 12, 0, 13, 21, 0, 7, 18, 16, 19, 5, 20, 5, 6, 15, 9, 0, 22, 18, 28, 0, 21, 13, 17, 11, 15, 9, 28, 19, 19, 15, 13, 7, 5, 22, 13, 18, 17, 0, 18, 10, 28, 17, 28, 15, 

In [18]:
print(token_dict)

{0: ' ', 1: ',', 2: '.', 3: 'A', 4: 'T', 5: 'a', 6: 'b', 7: 'c', 8: 'd', 9: 'e', 10: 'f', 11: 'g', 12: 'h', 13: 'i', 14: 'k', 15: 'l', 16: 'm', 17: 'n', 18: 'o', 19: 'p', 20: 'r', 21: 's', 22: 't', 23: 'u', 24: 'v', 25: 'w', 26: 'x', 27: 'y', 28: ' a'}
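The new entry 28 is just the string of token 0 followed by the string of token 5. An alternative bookkeeping, sketched below, records which pair each new id replaced, so both the merged string and its character-level expansion can be rebuilt from its parents. The names merges, add_merge and expand, and the second merge with id 29, are hypothetical and only for illustration.

```python
token_dict = {0: " ", 5: "a", 17: "n"}  # excerpt of the base vocabulary above
merges = {}  # hypothetical: merged token id -> the pair of ids it replaced

def add_merge(pair, new_id):
    merges[new_id] = pair
    # a merged token's string is the concatenation of its two parts
    token_dict[new_id] = token_dict[pair[0]] + token_dict[pair[1]]

def expand(token_id):
    # recursively unwind a merged token into base (character-level) ids
    if token_id not in merges:
        return [token_id]
    left, right = merges[token_id]
    return expand(left) + expand(right)

add_merge((0, 5), 28)    # " " + "a" -> " a", the merge made in the cell above
add_merge((28, 17), 29)  # hypothetical next merge: " a" + "n" -> " an"

print(repr(token_dict[29]))  # ' an'
print(expand(29))            # [0, 5, 17]
```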


In [19]:
while token_counter < 32:  # continue merging until a desired vocab size is achieved
    frequencies = get_frequencies(tokens)
    highest_pair = max(frequencies, key=frequencies.get)

    if frequencies[highest_pair] == 1:  # stop when no pair occurs more than once
        break

    tokens = merge_pairs(tokens, highest_pair, token_counter)
    token_dict[token_counter] = decode(highest_pair)  # add new token to the dictionary
    token_counter += 1

print(token_dict)

{0: ' ', 1: ',', 2: '.', 3: 'A', 4: 'T', 5: 'a', 6: 'b', 7: 'c', 8: 'd', 9: 'e', 10: 'f', 11: 'g', 12: 'h', 13: 'i', 14: 'k', 15: 'l', 16: 'm', 17: 'n', 18: 'o', 19: 'p', 20: 'r', 21: 's', 22: 't', 23: 'u', 24: 'v', 25: 'w', 26: 'x', 27: 'y', 28: ' a', 29: ' t', 30: 'in', 31: 'es'}
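The loop above and the helpers from earlier cells can be wrapped into a single function. This is a sketch of one way to package it, not the contents of BPE.py; train and the merges list it returns are hypothetical names. It stops at a target vocabulary size or when no pair repeats, and it records each merge in the order it was learned, because encoding new text later has to replay the merges in that order. The sample string "aaabdaaabac" is chosen for illustration.

```python
def get_frequencies(tokens):
    frequencies = {}
    for pair in zip(tokens, tokens[1:]):
        frequencies[pair] = frequencies.get(pair, 0) + 1
    return frequencies

def merge_pairs(tokens, pair, token_id):
    new_tokens, i = [], 0
    while i < len(tokens):
        if i < len(tokens) - 1 and (tokens[i], tokens[i + 1]) == pair:
            new_tokens.append(token_id)
            i += 2
        else:
            new_tokens.append(tokens[i])
            i += 1
    return new_tokens

def train(text, target_vocab_size):
    """Hypothetical helper: learn merges until the vocabulary reaches
    target_vocab_size or no adjacent pair occurs more than once."""
    token_dict = dict(enumerate(sorted(set(text))))
    char_to_id = {char: token_id for token_id, char in token_dict.items()}
    tokens = [char_to_id[char] for char in text]
    merges = []  # (pair, new_id) in the order they were learned

    while len(token_dict) < target_vocab_size:
        frequencies = get_frequencies(tokens)
        if not frequencies:  # fewer than two tokens left
            break
        highest_pair = max(frequencies, key=frequencies.get)
        if frequencies[highest_pair] < 2:  # no pair repeats
            break
        new_id = len(token_dict)
        tokens = merge_pairs(tokens, highest_pair, new_id)
        token_dict[new_id] = token_dict[highest_pair[0]] + token_dict[highest_pair[1]]
        merges.append((highest_pair, new_id))

    return tokens, token_dict, merges

tokens, token_dict, merges = train("aaabdaaabac", target_vocab_size=10)
print(token_dict)  # {0: 'a', 1: 'b', 2: 'c', 3: 'd', 4: 'aa', 5: 'aaa', 6: 'aaab'}
print(tokens)      # [6, 3, 6, 0, 2]
```

Here training stops after three merges even though the target was 10, because no pair repeats anymore.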


In [20]:
frequencies = get_frequencies(tokens)
highest_pair = max(frequencies, key=frequencies.get)

# alternatively, keep merging until no pair occurs more than once
while frequencies[highest_pair] > 1:
    tokens = merge_pairs(tokens, highest_pair, token_counter)
    token_dict[token_counter] = decode(highest_pair)  # add new token to the dictionary
    token_counter += 1

    frequencies = get_frequencies(tokens)
    highest_pair = max(frequencies, key=frequencies.get)

print(token_dict)

{0: ' ', 1: ',', 2: '.', 3: 'A', 4: 'T', 5: 'a', 6: 'b', 7: 'c', 8: 'd', 9: 'e', 10: 'f', 11: 'g', 12: 'h', 13: 'i', 14: 'k', 15: 'l', 16: 'm', 17: 'n', 18: 'o', 19: 'p', 20: 'r', 21: 's', 22: 't', 23: 'u', 24: 'v', 25: 'w', 26: 'x', 27: 'y', 28: ' a', 29: ' t', 30: 'in', 31: 'es', 32: 't ', 33: 'he', 34: 'ed', 35: 'it', 36: 'or', 37: 'ic', 38: ' an', 39: 'ro', 40: 'ing', 41: 'is', 42: ' s', 43: 'ess', 44: 'en', 45: 'pu', 46: ' to', 47: ' te', 48: ' tex', 49: ' text ', 50: 'pro', 51: ' c', 52: ' co', 53: ' com', 54: ' o', 55: ' ap', 56: ' app', 57: ' appl', 58: 'ica', 59: 'le', 60: 'edit', 61: 'editor', 62: 'ent', 63: 'al', 64: ' in', 65: 'put ', 66: 'tr', 67: 'am', 68: ' and', 69: 'ct', 70: ' the', 71: ' text pro', 72: ' text proc', 73: ' text process', 74: ' text processing', 75: ' w', 76: ' wh', 77: ' whic', 78: ' which', 79: ' which ', 80: ' which is', 81: 'her', 82: '. ', 83: ' r', 84: ' res', 85: ' resu', 86: ' resul', 87: ' applica', 88: 'ble', 89: 'ble to', 90: ' f', 91: 'ur'}
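The merges so far only compress the training text. One common way to tokenize a new string, sketched here rather than taken from BPE.py, is to split it into characters and replay the learned merges in the order they were learned. To keep the example self-contained, the vocabulary and the four merges below are reconstructed from entries 28 to 31 of the table above (' a', ' t', 'in', 'es'); encode_with_merges is a hypothetical name.

```python
token_dict = {0: " ", 5: "a", 9: "e", 13: "i", 17: "n", 21: "s", 22: "t",
              28: " a", 29: " t", 30: "in", 31: "es"}
# (pair, new_id) in the order they were learned
merges = [((0, 5), 28), ((0, 22), 29), ((13, 17), 30), ((9, 21), 31)]

def merge_pairs(tokens, pair, token_id):
    new_tokens, i = [], 0
    while i < len(tokens):
        if i < len(tokens) - 1 and (tokens[i], tokens[i + 1]) == pair:
            new_tokens.append(token_id)
            i += 2
        else:
            new_tokens.append(tokens[i])
            i += 1
    return new_tokens

def encode_with_merges(text):
    # start from single characters (only base tokens have length 1)
    char_to_id = {token: token_id for token_id, token in token_dict.items() if len(token) == 1}
    tokens = [char_to_id[char] for char in text]
    # replay every merge in the order it was learned
    for pair, new_id in merges:
        tokens = merge_pairs(tokens, pair, new_id)
    return tokens

def decode(tokens):
    return "".join(token_dict[token_id] for token_id in tokens)

encoded = encode_with_merges("sit in a nest")
print(encoded)          # [21, 13, 22, 0, 30, 28, 0, 17, 31, 22]
print(decode(encoded))  # sit in a nest
```

The order matters because later entries in the table, such as ' text ', are built from earlier merged tokens and can only form once those exist.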


In [22]:
print(tokens)
print("length:", len(tokens))

[3, 17, 0, 61, 0, 43, 62, 13, 63, 15, 27, 64, 24, 18, 14, 31, 38, 64, 65, 21, 66, 9, 67, 68, 0, 8, 13, 20, 9, 69, 21, 0, 35, 46, 70, 74, 0, 44, 24, 13, 39, 17, 16, 62, 1, 80, 0, 9, 35, 81, 28, 53, 16, 5, 17, 8, 42, 33, 15, 15, 0, 36, 28, 49, 61, 82, 4, 33, 86, 22, 40, 54, 23, 22, 65, 41, 87, 89, 90, 91, 22, 81, 74, 1, 70, 90, 30, 63, 86, 32, 18, 10, 80, 53, 19, 5, 20, 5, 89, 28, 42, 40, 59, 87, 22, 13, 18, 17, 54, 10, 38, 28, 15, 11, 36, 35, 12, 16, 57, 13, 34, 54, 17, 7, 9, 0, 6, 27, 28, 0, 16, 36, 9, 42, 18, 19, 12, 41, 22, 58, 22, 34, 68, 42, 66, 23, 69, 91, 34, 53, 45, 22, 9, 20, 0, 50, 11, 20, 67, 82]
length: 160


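As you can see, merging until no pair repeats compresses the 360-character text to roughly 160 tokens, but grows the vocabulary from 28 to 92 entries, many of them long and specific to this one passage (for example ' text processing' and ' which is'). The stopping rule is therefore the main thing to tune: stopping at a fixed vocabulary size, as in the token_counter < 32 loop, keeps the vocabulary small at the cost of less compression. Many other adjustments could be made here as well.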
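To see the trade-off concretely, one can train with different target vocabulary sizes and compare how many tokens the text shrinks to. The sketch below is self-contained: train_bpe is a hypothetical helper that uses the same greedy most-frequent-pair rule as the notebook, and the sample text is a shortened phrase built from words in the passage above. It also checks that decoding always returns the original text, which is a useful regression test for merge_pairs.

```python
def train_bpe(text, target_vocab_size):
    """Hypothetical helper: greedy BPE up to target_vocab_size entries."""
    token_dict = dict(enumerate(sorted(set(text))))
    char_to_id = {char: token_id for token_id, char in token_dict.items()}
    tokens = [char_to_id[char] for char in text]

    while len(token_dict) < target_vocab_size:
        counts = {}
        for pair in zip(tokens, tokens[1:]):
            counts[pair] = counts.get(pair, 0) + 1
        if not counts or max(counts.values()) < 2:  # nothing repeats
            break
        pair = max(counts, key=counts.get)
        new_id = len(token_dict)
        token_dict[new_id] = token_dict[pair[0]] + token_dict[pair[1]]

        # replace non-overlapping occurrences of pair, left to right
        merged, i = [], 0
        while i < len(tokens):
            if i < len(tokens) - 1 and (tokens[i], tokens[i + 1]) == pair:
                merged.append(new_id)
                i += 2
            else:
                merged.append(tokens[i])
                i += 1
        tokens = merged

    return tokens, token_dict

text = "the text processing environment and the text editor"
base = len(set(text))
for extra in (0, 5, 10, 20):
    tokens, token_dict = train_bpe(text, base + extra)
    assert "".join(token_dict[t] for t in tokens) == text  # lossless round trip
    print(f"vocab size {len(token_dict):3d} -> {len(tokens):3d} tokens")
```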

Next, read through BPE.py to see the complete implementation.
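The introduction says the code writes its result to a JSON file; the exact format BPE.py uses is not shown in this notebook. As a hypothetical sketch with the standard json module, note that JSON object keys are always strings, so integer token ids have to be converted back when the data is read in. json.dump and json.load behave the same way when writing to and reading from a file object.

```python
import json

token_dict = {0: " ", 5: "a", 28: " a"}  # small excerpt of the table above

# json.dumps turns the int keys into strings: {"0": " ", "5": "a", "28": " a"}
serialized = json.dumps(token_dict)

# convert the keys back to int when reading the data back in
loaded = {int(token_id): token for token_id, token in json.loads(serialized).items()}
print(loaded == token_dict)  # True
```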