### Create byte pair encoding (BPE) for a given text

In [151]:
origin_text ="""Taylor Alison Swift (born December 13, 1989) is an American singer-songwriter. Her artistry, songwriting, and entrepreneurship have influenced the music industry and popular culture. A subject of widespread media coverage, Swift is an advocate of artists' rights and has impacted politics.
Swift began professional songwriting at age 14. She signed with Big Machine Records in 2005 and achieved prominence as a country pop singer with the albums Taylor Swift (2006) and Fearless (2008). Their singles "Teardrops on My Guitar", "Love Story", and "You Belong with Me" were crossover successes on country and pop radio formats and brought Swift mainstream fame. She experimented with rock and electronic styles on her next albums, Speak Now (2010) and Red (2012), respectively, with the latter featuring her first Billboard Hot 100 number-one single, "We Are Never Ever Getting Back Together". Swift recalibrated her image from country to pop with 1989 (2014), a synth-pop album containing the chart-topping songs "Shake It Off", "Blank Space", and "Bad Blood". Media scrutiny inspired the hip-hop-influenced Reputation (2017) and its number-one single "Look What You Made Me Do".
After signing with Republic Records in 2018, Swift released the eclectic pop album Lover (2019) and the autobiographical documentary Miss Americana (2020). She explored indie folk styles on the 2020 albums Folklore and Evermore, subdued electropop on Midnights (2022), and re-recorded four albums subtitled Taylor's Version[a] after a dispute with Big Machine. These albums spawned the number-one songs "Cruel Summer", "Cardigan", "Willow", "Anti-Hero", "All Too Well", and "Is It Over Now?". Her Eras Tour (2023–2024) and its accompanying concert film became the highest-grossing tour and concert film of all time, respectively. Swift has directed videos and films such as Folklore: The Long Pond Studio Sessions (2020) and All Too Well: The Short Film (2021).
One of the world's best-selling musicians, Swift has sold over 200 million records as of 2019. She is the highest-grossing female touring act, the most-streamed woman on Spotify and Apple Music, and the first billionaire with music as the main source of income. Six of her albums have opened with over one million sales in a week. The 2023 Time Person of the Year, Swift has appeared on lists such as Rolling Stone's 100 Greatest Songwriters of All Time, Billboard's Greatest of All Time Artists, and Forbes' World's 100 Most Powerful Women. Her accolades include 14 Grammy Awards, a Primetime Emmy Award, 40 American Music Awards, 40 Billboard Music Awards, and 23 MTV Video Music Awards; she has won the Grammy Award for Album of the Year, the MTV Video Music Award for Video of the Year, and the IFPI Global Recording Artist of the Year a record four times each.'
"""

# convert to bytes
xs = origin_text.encode('utf-8')

origin_text[0:10], xs[0:10], xs[0], xs[1], chr(xs[0]), chr(xs[1])

('Taylor Ali', b'Taylor Ali', 84, 97, 'T', 'a')

In [152]:
xs = xs.decode('utf-8')
xs == origin_xs

True

In [153]:
index_2_token = {i: chr(i) for i in range(256)} # index to token
n_vocab = 276
merged_bytes = {} # index to token
n_extend_token = 276 - len(index_2_token)

In [154]:
index_2_token[97]

'a'

In [155]:
# count frequency of each pair of bytes and get highest frequency pair
from collections import defaultdict

def get_stats(xs, pair_freq=None):
    if pair_freq is None: pair_freq = defaultdict(int)
    for x1, x2 in zip(xs, xs[1:]):
        pair = (x1, x2)
        pair_freq[pair] += 1
    return pair_freq

token = lambda idx: index_2_token[idx]

def merge(xs, pair, new_token_id):
    i = 0
    new_xs = []
    while i < len(xs):
        if (i < len(xs) - 1) and ((xs[i], xs[i+1]) == pair):
            new_xs.append(new_token_id)
            i += 2
        else:
            new_xs.append(xs[i])
            i += 1
    return new_xs

print(merge([5, 6, 6, 7, 9, 1], (6, 7), 99))

def train_tokenizer():
    xs = origin_text.encode('utf-8') # convert str to list of index
    xs = list(xs) # convert to list of index
    # print length of xs
    print(len(xs))
    for i in range(n_extend_token):
        next_index = len(index_2_token)
        pair_freq = get_stats(xs) # get common index pair
        _, pair = max([(v, k) for k, v in pair_freq.items()])
        merged_bytes[pair] = next_index
        index_2_token[next_index] = token(pair[0]) + token(pair[1])
        xs = merge(xs, pair, next_index)
    print(len(xs))
    
train_tokenizer()

[5, 6, 99, 9, 1]
2809
2195


In [156]:
# vocab
for byte_pair, merge_idx in merged_bytes.items():
    print(f"index {merge_idx}, byte_pair {byte_pair}, '{index_2_token[byte_pair[0]]}'+'{index_2_token[byte_pair[1]]}'='{index_2_token[merge_idx]}'")

index 256, byte_pair (101, 32), 'e'+' '='e '
index 257, byte_pair (32, 97), ' '+'a'=' a'
index 258, byte_pair (100, 32), 'd'+' '='d '
index 259, byte_pair (115, 32), 's'+' '='s '
index 260, byte_pair (101, 114), 'e'+'r'='er'
index 261, byte_pair (105, 110), 'i'+'n'='in'
index 262, byte_pair (111, 110), 'o'+'n'='on'
index 263, byte_pair (116, 104), 't'+'h'='th'
index 264, byte_pair (257, 110), ' a'+'n'=' an'
index 265, byte_pair (116, 32), 't'+' '='t '
index 266, byte_pair (264, 258), ' an'+'d '=' and '
index 267, byte_pair (44, 32), ','+' '=', '
index 268, byte_pair (263, 256), 'th'+'e '='the '
index 269, byte_pair (114, 101), 'r'+'e'='re'
index 270, byte_pair (50, 48), '2'+'0'='20'
index 271, byte_pair (119, 105), 'w'+'i'='wi'
index 272, byte_pair (97, 114), 'a'+'r'='ar'
index 273, byte_pair (261, 103), 'in'+'g'='ing'
index 274, byte_pair (111, 114), 'o'+'r'='or'
index 275, byte_pair (108, 108), 'l'+'l'='ll'


### Encode

In [157]:
# from list of index to list of index
# find each pair of index and replace with new index


def encode(text):
    xs = list(text.encode('utf-8'))
    while len(xs) >= 2:
        pair_freq = get_stats(xs)
        pair_freq = [(merged_bytes.get(pair, float('inf')), pair) for _, pair in pair_freq]
        idx, pair = min(pair_freq)
        if idx == float('inf'): break
        xs = merge(xs, pair, idx)
    return xs

text = "Hello, world!123"
encoded_text = encode(text)
encoded_text

[72, 101, 108, 108, 111, 44, 32, 119, 111, 114, 108, 100, 33, 49, 50, 51]

### Decode

In [158]:
# convert each index into token
def decode(xs):
    return ''.join([index_2_token[x] for x in xs])

decoded_text = decode(encoded_text)
decoded_text

'Hello, world!123'

In [159]:
# test

text = "Marshall Bruce Mathers III, known professionally as Eminem, is an American rapper. He is credited with popularizing hip hop in Middle America and is often regarded as one of the greatest rappers of all time."
encoded_text = encode(text)
decoded_text = decode(encoded_text)
text == decoded_text

True

### Using chatGTP 2 style tokenizer

In [160]:
import regex as re

pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

# This line of code is a regular expression (regex) pattern used to match certain types of strings in a larger body of text.
# Here's a breakdown of what each part of the pattern does:

# - `'s|'t|'re|'ve|'m|'ll|'d`: This matches contractions like 's, 't, 're, 've, 'm, 'll, and 'd. The pipe character `|` means "or", so the pattern matches any one of these contractions.

# - ` ?\p{L}+`: This matches one or more Unicode letters, possibly preceded by a space. `\p{L}` is a Unicode property escape that matches any kind of letter from any language. The `?` means "zero or one of the preceding element", so a space is optional before the letter. The `+` means "one or more of the preceding element".

# - ` ?\p{N}+`: This matches one or more Unicode numbers, possibly preceded by a space. `\p{N}` is a Unicode property escape that matches any kind of numeric character in any script.

# - ` ?[^\s\p{L}\p{N}]+`: This matches one or more characters that are not whitespace, letters, or numbers, possibly preceded by a space. The `^` inside the square brackets negates the character set, so `[^\s\p{L}\p{N}]` matches any character that is not a whitespace character (`\s`), a letter (`\p{L}`), or a number (`\p{N}`).

# - `\s+(?!\S)|\s+`: This matches one or more whitespace characters. The `(?!\S)` is a negative lookahead that asserts that what immediately follows the current position in the string is not a non-whitespace character.

# This pattern is typically used with the `re.findall()` function to split a string into tokens. The tokens can include contractions, words, numbers, punctuation, and whitespace.

text = """for i in range(1, 101):
    if i % 3 == 0 and i % 5 == 0:
        print("FizzBuzz")
    elif i % 3 == 0:
        print("Fizz")
    elif i % 5 == 0:
        print("Buzz")
    else:
        print(i)"""
words = pat.findall(text)

encoded_text = sum([encode(w) for w in words], [])
decoded_text = decode(encoded_text)
assert text == decoded_text
print(decoded_text)

for i in range(1, 101):
    if i % 3 == 0 and i % 5 == 0:
        print("FizzBuzz")
    elif i % 3 == 0:
        print("Fizz")
    elif i % 5 == 0:
        print("Buzz")
    else:
        print(i)


In [161]:
# put all together
token = lambda idx: index_2_token[idx]

class MyTokenizer:
    def __init__(self, vocab_size=276):
        self.vocab_size = vocab_size
        self.index_2_token = {i: chr(i) for i in range(256)} # index to token
        self.merged_bytes = {} # index to token
        self.n_extend_token = self.vocab_size - len(self.index_2_token)

    def token(self, idx): return self.index_2_token[idx]

    def train_tokenizer(self, train_text):
        xs = train_text.encode('utf-8') # convert str to list of index
        xs = list(xs) # convert to list of index
        # print length of xs
        # print(f"original length {len(xs)}")
        for i in range(self.n_extend_token):
            next_index = len(self.index_2_token)
            stats_pair = get_stats(xs) # get common index pair
            _, pair = max([(v, k) for k, v in stats_pair.items()])
            self.merged_bytes[pair] = next_index
            self.index_2_token[next_index] = self.token(pair[0]) + self.token(pair[1])
            xs = merge(xs, pair, next_index)
        # print(f"length after tokenize {len(xs)}")

    def print_vocab(self):
        for byte_pair, merge_idx in self.merged_bytes.items():
            print(f"index {merge_idx}, byte_pair {byte_pair}, '{self.index_2_token[byte_pair[0]]}'+'{self.index_2_token[byte_pair[1]]}'='{self.index_2_token[merge_idx]}'")

    def encode(self, text):
        xs = list(text.encode('utf-8'))
        while len(xs) >= 2:
            pair_freq = get_stats(xs)
            pair_freq = [(self.merged_bytes.get(pair, float('inf')), pair) for _, pair in pair_freq]
            idx, pair = min(pair_freq)
            if idx == float('inf'): break
            xs = merge(xs, pair, idx)
        return xs

    def decode(self, xs):
        return ''.join([self.index_2_token[x] for x in xs])
    


tokenizer = MyTokenizer()

text = "Marshall Bruce Mathers III, known professionally as Eminem, is an American rapper. He is credited with popularizing hip hop in Middle America and is often regarded as one of the greatest rappers of all time."
tokenizer.train_tokenizer(text)

tokenizer.print_vocab()


encoded_text = tokenizer.encode(text)
decoded_text = tokenizer.decode(encoded_text)
assert text == decoded_text

decoded_text
text = """for i in range(1, 101):
    if i % 3 == 0 and i % 5 == 0:
        print("FizzBuzz")
    elif i % 3 == 0:
        print("Fizz")
    elif i % 5 == 0:
        print("Buzz")
    else:
        print(i)"""
encoded_text = tokenizer.encode(text)
decoded_text = tokenizer.decode(encoded_text)
assert text == decoded_text



index 256, byte_pair (115, 32), 's'+' '='s '
index 257, byte_pair (110, 32), 'n'+' '='n '
index 258, byte_pair (101, 114), 'e'+'r'='er'
index 259, byte_pair (101, 32), 'e'+' '='e '
index 260, byte_pair (111, 102), 'o'+'f'='of'
index 261, byte_pair (32, 97), ' '+'a'=' a'
index 262, byte_pair (116, 104), 't'+'h'='th'
index 263, byte_pair (116, 101), 't'+'e'='te'
index 264, byte_pair (114, 101), 'r'+'e'='re'
index 265, byte_pair (108, 108), 'l'+'l'='ll'
index 266, byte_pair (105, 256), 'i'+'s '='is '
index 267, byte_pair (97, 114), 'a'+'r'='ar'
index 268, byte_pair (265, 32), 'll'+' '='ll '
index 269, byte_pair (261, 256), ' a'+'s '=' as '
index 270, byte_pair (258, 256), 'er'+'s '='ers '
index 271, byte_pair (258, 105), 'er'+'i'='eri'
index 272, byte_pair (271, 99), 'eri'+'c'='eric'
index 273, byte_pair (272, 97), 'eric'+'a'='erica'
index 274, byte_pair (114, 97), 'r'+'a'='ra'
index 275, byte_pair (274, 112), 'ra'+'p'='rap'


In [162]:
### Split text before tokenization
import regex as re

pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

words = pat.findall("Hi, I'm a student. I'm learning NLP.")
words

['Hi',
 ',',
 ' I',
 "'m",
 ' a',
 ' student',
 '.',
 ' I',
 "'m",
 ' learning',
 ' NLP',
 '.']

In [163]:
pair = get_stats('ab ab')
pair

defaultdict(int, {('a', 'b'): 2, ('b', ' '): 1, (' ', 'a'): 1})

In [164]:
pair = get_stats('ab')
pair = get_stats(' ab', pair)
pair

defaultdict(int, {('a', 'b'): 2, (' ', 'a'): 1})

In [165]:
# put all together
token = lambda idx: index_2_token[idx]

class MyTokenizer:
    def __init__(self, vocab_size=276):
        self.vocab_size = vocab_size
        self.index_2_token = {i: chr(i) for i in range(256)} # index to token
        self.merged_bytes = {} # index to token
        self.n_extend_token = self.vocab_size - len(self.index_2_token)

    def token(self, idx): return self.index_2_token[idx]

    def train_tokenizer(self, train_text):
        words = pat.findall(train_text)
        xs = [list(w.encode('utf-8')) for w in words] # convert str to list of index
        for i in range(self.n_extend_token):
            next_index = len(self.index_2_token)
            stats_pair = None
            for x in xs: stats_pair = get_stats(x, stats_pair)
            _, pair = max([(v, k) for k, v in stats_pair.items()])
            self.merged_bytes[pair] = next_index
            self.index_2_token[next_index] = self.token(pair[0]) + self.token(pair[1])
            xs = [merge(x, pair, next_index) for x in xs]
        # print(f"length after tokenize {len(xs)}")

    def print_vocab(self):
        for byte_pair, merge_idx in self.merged_bytes.items():
            print(f"index {merge_idx}, byte_pair {byte_pair}, '{self.index_2_token[byte_pair[0]]}'+'{self.index_2_token[byte_pair[1]]}'='{self.index_2_token[merge_idx]}'")

    def _encode(self, word):
        xs = list(word.encode('utf-8'))
        while len(xs) >= 2:
            pair_freq = get_stats(xs)
            pair_freq = [(self.merged_bytes.get(pair, float('inf')), pair) for _, pair in pair_freq]
            idx, pair = min(pair_freq)
            if idx == float('inf'): break
            xs = merge(xs, pair, idx)
        return xs

    def encode(self, text):
        words = pat.findall(text)
        return sum([self._encode(w) for w in words], [])

    def decode(self, xs):
        return ''.join([self.index_2_token[x] for x in xs])
    


tokenizer = MyTokenizer()

text = "Marshall Bruce Mathers III, known professionally as Eminem, is an American rapper. He is credited with popularizing hip hop in Middle America and is often regarded as one of the greatest rappers of all time."
tokenizer.train_tokenizer(text)

tokenizer.print_vocab()


encoded_text = tokenizer.encode(text)
decoded_text = tokenizer.decode(encoded_text)
assert text == decoded_text

# decoded_text
text = """for i in range(1, 101):
    if i % 3 == 0 and i % 5 == 0:
        print("FizzBuzz")
    elif i % 3 == 0:
        print("Fizz")
    elif i % 5 == 0:
        print("Buzz")
    else:
        print(i)"""
encoded_text = tokenizer.encode(text)
decoded_text = tokenizer.decode(encoded_text)
assert text == decoded_text


index 256, byte_pair (101, 114), 'e'+'r'='er'
index 257, byte_pair (32, 97), ' '+'a'=' a'
index 258, byte_pair (111, 102), 'o'+'f'='of'
index 259, byte_pair (32, 105), ' '+'i'=' i'
index 260, byte_pair (259, 115), ' i'+'s'=' is'
index 261, byte_pair (116, 104), 't'+'h'='th'
index 262, byte_pair (116, 101), 't'+'e'='te'
index 263, byte_pair (114, 101), 'r'+'e'='re'
index 264, byte_pair (108, 108), 'l'+'l'='ll'
index 265, byte_pair (97, 114), 'a'+'r'='ar'
index 266, byte_pair (32, 258), ' '+'of'=' of'
index 267, byte_pair (257, 115), ' a'+'s'=' as'
index 268, byte_pair (257, 110), ' a'+'n'=' an'
index 269, byte_pair (256, 115), 'er'+'s'='ers'
index 270, byte_pair (256, 105), 'er'+'i'='eri'
index 271, byte_pair (270, 99), 'eri'+'c'='eric'
index 272, byte_pair (271, 97), 'eric'+'a'='erica'
index 273, byte_pair (114, 97), 'r'+'a'='ra'
index 274, byte_pair (273, 112), 'ra'+'p'='rap'
index 275, byte_pair (274, 112), 'rap'+'p'='rapp'
