
# BPE

In this notebook, we implement a simple byte pair encoding (BPE) algorithm. The starting point is a string, which is typically the input to an NLP task.

In [1]:
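s = "hello, 你好！"
# in UTF-8, each non-ASCII character here (你, 好, and the full-width ！)
# takes 3 bytes, so these 10 characters encode to 7 + 3 * 3 = 16 bytes.
# note: \x introduces a hex-escaped byte in a Python bytes literal.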

s_bytes = s.encode('utf-8')

print(s_bytes)
print(len(s_bytes))

b'hello, \xe4\xbd\xa0\xe5\xa5\xbd\xef\xbc\x81'
16
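To see where the 16 comes from, a short sketch can count the UTF-8 bytes of each character separately; the ASCII characters take 1 byte each and the three non-ASCII ones take 3 bytes each.

```python
# Count how many UTF-8 bytes each character of the example string needs.
s = "hello, 你好！"

byte_lengths = [len(ch.encode("utf-8")) for ch in s]
print(byte_lengths)       # [1, 1, 1, 1, 1, 1, 1, 3, 3, 3]
print(sum(byte_lengths))  # 16, the same as len(s.encode("utf-8"))
```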


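For better modularity, we define several helper routines for converting between strings, bytes, and integer arrays. The integer form is more than a visualization aid: BPE assigns new token IDs starting at 256, which no longer fit in a single byte, so the algorithm has to operate on plain lists of integers.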

In [4]:
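def string_to_bytes(s: str) -> bytes:
    return s.encode('utf-8')


def bytes_to_string(b: bytes) -> str:
    return b.decode('utf-8')


def string_to_ints(s: str) -> list[int]:
    # iterating over a bytes object already yields ints in 0..255
    return list(string_to_bytes(s))


def ints_to_string(ints: list[int]) -> str:
    # bytes() only accepts values in 0..255, i.e. raw (un-merged) bytes
    return bytes_to_string(bytes(ints))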


b = string_to_bytes(s)
a = string_to_ints(s)
print(f"string             = {s}")
print(f"bytes              = {b}")
print(f"recover from bytes = {bytes_to_string(b)}")
print(f"int array          = {a}")
print(f"recover from int   = {ints_to_string(a)}")



string             = hello, 你好！
bytes              = b'hello, \xe4\xbd\xa0\xe5\xa5\xbd\xef\xbc\x81'
recover from bytes = hello, 你好！
int array          = [104, 101, 108, 108, 111, 44, 32, 228, 189, 160, 229, 165, 189, 239, 188, 129]
recover from int   = hello, 你好！
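As a quick check of why the integer form matters: every value that comes out of UTF-8 encoding is a byte in 0..255, but the token IDs that BPE assigns start at 256, and Python's `bytes` constructor rejects those. A minimal sketch:

```python
# Raw UTF-8 bytes are integers in 0..255, so they fit in a bytes object ...
raw = list("hi".encode("utf-8"))
print(raw)  # [104, 105]

# ... but a merged token id such as 256 does not.
try:
    bytes([104, 256])
    fits_in_bytes = True
except ValueError as err:
    fits_in_bytes = False
    print(f"ValueError: {err}")
```

This is why the merged sequences below are kept as `list[int]` and only converted back to `bytes` after every merged token has been expanded.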


Now that we can convert a string into an integer array, the next step is to write the BPE algorithm.

The idea of the algorithm is simple:
- Loop
    - Find the most frequent pair of adjacent tokens;
    - Assign a new token ID to the pair;
    - Replace every occurrence of the pair with the new token.
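A single iteration of this loop can be sketched on a toy input before writing the full function. The input string `"aaabdaaabac"` and the variable names are illustrative choices, not part of the notebook:

```python
from collections import Counter

# Toy input: the bytes of "aaabdaaabac".
tokens = list(b"aaabdaaabac")
new_id = 256  # first id after the 256 raw byte values

# 1. Count every pair of adjacent tokens and pick the most frequent one.
pair_counts = Counter(zip(tokens, tokens[1:]))
best = max(pair_counts, key=pair_counts.get)
print(best, pair_counts[best])  # (97, 97) 4

# 2./3. Give the pair a new id and replace its occurrences left to right.
merged, i = [], 0
while i < len(tokens):
    if i + 1 < len(tokens) and (tokens[i], tokens[i + 1]) == best:
        merged.append(new_id)
        i += 2
    else:
        merged.append(tokens[i])
        i += 1
print(merged)  # [256, 97, 98, 100, 256, 97, 98, 97, 99]
```

Note that each overlapping run `aaa` contributes two `(97, 97)` pairs to the count but allows only one merge, so the count can overestimate how much a merge actually shortens the sequence.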


In [9]:
from collections import Counter


def bpe(
        a_in: list[int],
        maxit: int = 1,
        DEBUG: bool = False
) -> dict[tuple[int, int], int]:
    """Learn up to `maxit` BPE merges from `a_in`; return the token map."""
    a = a_in[:]               # work on a copy; the caller's list is untouched
    token_map = dict()        # (left, right) -> new token id
    current_token_max = 255   # ids 0..255 are the raw byte values

    for _ in range(maxit):
        # compute frequency of consecutive pairs
        freqs = Counter(zip(a, a[1:]))
        if not freqs:         # fewer than two tokens left: nothing to merge
            break

        # pick the most frequent pair (ties go to the pair seen first)
        max_pair = max(freqs, key=freqs.get)
        current_token_max += 1
        token_map[max_pair] = current_token_max

        # replace every occurrence of max_pair, scanning left to right
        merged = []
        i = 0
        while i < len(a):
            if i < len(a) - 1 and (a[i], a[i+1]) == max_pair:
                merged.append(current_token_max)
                i += 2
            else:
                merged.append(a[i])
                i += 1
        a = merged

        if DEBUG:
            print(f"{a = }")
            print(f"{token_map = }")
            print(f"{freqs = }")

    return token_map

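The output is a token map of type `dict[tuple[int, int], int]`: each learned pair of adjacent token IDs maps to the new token ID that replaces it. Since Python dicts preserve insertion order, the map also records the order in which the merges were learned.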
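One way to read a token map is to expand every new ID into the bytes it stands for. The sketch below does this with a hypothetical helper `build_vocab`, using the token map printed by the example run at the end of this notebook. It assumes the map's keys are in learned order (as `bpe` produces them), so both halves of each pair are already known when the pair is reached.

```python
# Token map printed by the example run at the end of this notebook.
token_map = {(44, 32): 256, (97, 116): 257, (108, 111): 258, (256, 104): 259,
             (101, 108): 260, (260, 258): 261, (261, 259): 262, (97, 256): 263,
             (257, 256): 264, (104, 262): 265}


def build_vocab(token_map):
    # hypothetical helper: ids 0..255 stand for themselves
    vocab = {i: bytes([i]) for i in range(256)}
    # assumes learned (insertion) order, so vocab[left] and vocab[right] exist
    for (left, right), new_id in token_map.items():
        vocab[new_id] = vocab[left] + vocab[right]
    return vocab


vocab = build_vocab(token_map)
for new_id in sorted(token_map.values()):
    print(new_id, vocab[new_id])
```

For example, token 265 expands to `b'hello, h'`: it is `h` followed by 262, which is itself built from 261 (`ello`) and 259 (`, h`).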

In [10]:
def encode(a_in: list[int], token_map: dict[tuple[int, int], int]) -> list[int]:
    """Apply the learned merges to `a_in`, in the order they were learned."""
    a = a_in[:]  # make a copy

    # dicts preserve insertion order, so this replays the merges in training
    # order; a later merge may combine tokens created by earlier merges,
    # e.g. (260, 258) -> 261, which a single left-to-right pass can miss
    for pair, new_token in token_map.items():
        merged = []
        i = 0
        while i < len(a):
            if i < len(a) - 1 and (a[i], a[i+1]) == pair:
                merged.append(new_token)
                i += 2
            else:
                merged.append(a[i])
                i += 1
        a = merged

    return a

In [11]:
def decode(a_in, token_map):
    a = a_in[:]
    n = len(a)
    i = 0
    # invert the map: new token id -> the (left, right) pair it replaced
    token_map_inv = {v: k for k, v in token_map.items()}

    while i < n:
        v = a[i]
        if v in token_map_inv:
            # expand v in place; do not advance i, because pair[0] may itself
            # be a merged token that needs further expansion
            pair = token_map_inv[v]
            a.pop(i)
            a.insert(i, pair[1])
            a.insert(i, pair[0])
            n += 1
        else:
            i += 1

    return a

In [16]:
s = 'hello, hello, hola, aloha, cat, bat, hat, mat'
a = string_to_ints(s)
token_map = bpe(a, maxit=10)
print(token_map)
print(a)
print(len(a))

a_encoded = encode(a, token_map)
print(a_encoded)
print(len(a_encoded))

# decode the *encoded* sequence and check the round trip
a_restored = decode(a_encoded, token_map)
assert a_restored == a
print(a_restored)
print(len(a_restored))


{(44, 32): 256, (97, 116): 257, (108, 111): 258, (256, 104): 259, (101, 108): 260, (260, 258): 261, (261, 259): 262, (97, 256): 263, (257, 256): 264, (104, 262): 265}
[104, 101, 108, 108, 111, 44, 32, 104, 101, 108, 108, 111, 44, 32, 104, 111, 108, 97, 44, 32, 97, 108, 111, 104, 97, 44, 32, 99, 97, 116, 44, 32, 98, 97, 116, 44, 32, 104, 97, 116, 44, 32, 109, 97, 116]
45
[265, 262, 111, 108, 263, 97, 258, 104, 263, 99, 264, 98, 257, 259, 264, 109, 257]
17
[104, 101, 108, 108, 111, 44, 32, 104, 101, 108, 108, 111, 44, 32, 104, 111, 108, 97, 44, 32, 97, 108, 111, 104, 97, 44, 32, 99, 97, 116, 44, 32, 98, 97, 116, 44, 32, 104, 97, 116, 44, 32, 109, 97, 116]
45
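The order in which merges are applied matters when encoding. The sketch below compares two strategies on small, hypothetical token maps: replaying each merge in the order it was learned, and a single left-to-right pass that merges any known pair as soon as it meets it. All function names here are illustrative only.

```python
def merge_pair(tokens, pair, new_id):
    """Replace each non-overlapping occurrence of `pair`, scanning left to right."""
    out, i = [], 0
    while i < len(tokens):
        if i + 1 < len(tokens) and (tokens[i], tokens[i + 1]) == pair:
            out.append(new_id)
            i += 2
        else:
            out.append(tokens[i])
            i += 1
    return out


def encode_in_merge_order(tokens, token_map):
    # replay merges in the order they were learned (dict insertion order)
    for pair, new_id in token_map.items():
        tokens = merge_pair(tokens, pair, new_id)
    return tokens


def encode_single_pass(tokens, token_map):
    # one left-to-right pass that merges any known pair it meets
    a, i = tokens[:], 0
    while i < len(a) - 1:
        pair = (a[i], a[i + 1])
        if pair in token_map:
            a[i:i + 2] = [token_map[pair]]
        else:
            i += 1
    return a


# Hypothetical map where "bc" was learned before "ab".
order_map = {(98, 99): 256, (97, 98): 257}
print(encode_in_merge_order(list(b"abc"), order_map))  # [97, 256]
print(encode_single_pass(list(b"abc"), order_map))     # [257, 99]

# Hypothetical nested merges: "lo", then "el", then "el" + "lo".
nested_map = {(108, 111): 256, (101, 108): 257, (257, 256): 258}
print(encode_in_merge_order(list(b"ello"), nested_map))  # [258]
print(encode_single_pass(list(b"ello"), nested_map))     # [257, 256]
```

In the nested case the single pass creates 257 and 256 side by side but has already moved past the position where they could be combined, so the final merge is never applied.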


Questions
- How should the number of iterations be chosen?
- Is there room for optimization?
- What are the current time and space complexities?