<a href="https://colab.research.google.com/github/stkao05/bpe/blob/main/tokenizer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
def get_stat(tokens, stats=None):
  """Count how often each adjacent pair occurs in ``tokens``.

  Args:
    tokens: sequence of token ids (must support slicing).
    stats: optional dict of pair counts to accumulate into; when given it is
      updated in place, so counts can be summed over several sequences.

  Returns:
    The dict mapping ``(left, right)`` pairs to counts. It is the same object
    as ``stats`` when one was passed. Keys appear in first-seen order.
  """
  counts = {} if stats is None else stats
  neighbours = zip(tokens, tokens[1:])
  for pair in neighbours:
    if pair in counts:
      counts[pair] += 1
    else:
      counts[pair] = 1
  return counts


def merge(tokens, pair, idx):
  """Return a new list where each occurrence of ``pair`` is replaced by ``idx``.

  The scan runs left to right and matches do not overlap. After a replacement
  the scan resumes past the pair's second element, so merging ``(5, 5)`` in
  ``[5, 5, 5]`` gives ``[idx, 5]``.
  """
  out = []
  pos = 0
  last = len(tokens) - 1
  while pos <= last:
    if pos < last and (tokens[pos], tokens[pos + 1]) == pair:
      out.append(idx)
      pos += 2
      continue
    out.append(tokens[pos])
    pos += 1
  return out

# Toy run of the BPE training loop. Repeatedly find the most frequent adjacent
# pair and replace it with a fresh id, starting at 256 (the first id past the
# single-byte range).
text = "hello world hello"
tokens = list(text.encode("utf-8"))

merges = {}  # (left_id, right_id) -> new_id, in creation order

src_vocab_size = 256
vocab_size = 259

n = vocab_size - src_vocab_size  # number of merges to perform
idx = src_vocab_size

for _ in range(n):
  stat = get_stat(tokens)
  if not stat:  # fewer than two tokens left: nothing to merge
    break
  # max() returns the first pair (in first-seen order) with the highest count.
  # It is a linear scan, with no need to sort all pairs.
  top_pair = max(stat, key=stat.get)
  tokens = merge(tokens, top_pair, idx)
  merges[top_pair] = idx
  idx += 1

tokens, merges

([258, 111, 32, 119, 111, 114, 108, 100, 32, 258, 111],
 {(104, 101): 256, (256, 108): 257, (257, 108): 258})

In [None]:
class BasicTokenizer:
  """Byte-level BPE tokenizer trained directly on the UTF-8 bytes of a text.

  Attributes:
    merges: dict mapping a merged pair ``(left_id, right_id)`` to its new id,
      in the order the merges were learned.
    vocab: dict mapping every id to the bytes it stands for.
  """

  def __init__(self):
    # Start from the untrained state (256 single-byte tokens, no merges) so
    # encode()/decode() work before train() instead of raising AttributeError.
    self.merges = {}
    self.vocab = {i: bytes([i]) for i in range(256)}

  def train(self, text, vocab_size, verbose=False):
    """Learn up to ``vocab_size - 256`` merges from ``text``.

    Previously learned merges are discarded. Training stops early once the
    token sequence has no adjacent pair left.

    Args:
      text: training corpus.
      vocab_size: target vocabulary size; must exceed 256.
      verbose: print each merge as it is learned.
    """
    assert vocab_size > 256

    self.merges = {}
    self.vocab = {i: bytes([i]) for i in range(256)}
    tokens = list(text.encode("utf-8"))
    num_merge = vocab_size - 256

    for i in range(num_merge):
      stat = get_stat(tokens)
      if not stat:  # single token left: nothing more to merge
        break

      idx = 256 + i
      top_pair = max(stat, key=stat.get)  # ties -> first pair seen
      tokens = merge(tokens, top_pair, idx)

      p0, p1 = top_pair
      self.merges[top_pair] = idx
      # A merged token's bytes are the concatenation of its two parts.
      self.vocab[idx] = self.vocab[p0] + self.vocab[p1]

      if verbose:
        print(f"merge {top_pair} -> {idx}")

  def encode(self, text):
    """Encode ``text`` into a list of token ids.

    Merges are replayed in the order they were learned (dict iteration order).
    This matters because a later merge may use ids created by earlier ones.
    """
    tokens = list(text.encode("utf-8"))
    for pair, idx in self.merges.items():
      if len(tokens) < 2:  # no pair left that any merge could match
        break
      tokens = merge(tokens, pair, idx)
    return tokens

  def decode(self, ids):
    """Decode token ids back to text.

    Bytes that are not valid UTF-8 (e.g. a multi-byte character cut in half)
    are decoded as U+FFFD replacement characters rather than raising. An id
    missing from ``vocab`` raises KeyError.
    """
    data = b"".join(self.vocab[token_id] for token_id in ids)
    return data.decode("utf-8", errors="replace")


# Round-trip sanity check on a toy corpus: decode(encode(s)) should give s back.
basic_tkr = BasicTokenizer()
basic_tkr.train("abc! abc! abc! hello hello world steven", vocab_size=300, verbose=False)
encoded = basic_tkr.encode("abc!")
basic_tkr.decode(encoded)

'abc!'

In [None]:
# prompt: download text from https://raw.githubusercontent.com/karpathy/minbpe/master/tests/taylorswift.txt into a string var

!curl -s https://raw.githubusercontent.com/karpathy/minbpe/master/tests/taylorswift.txt > taylor_swift.txt
with open("taylor_swift.txt") as f:
  txt = f.read()


In [None]:
# Train a 300-id vocabulary (up to 44 merges) on the downloaded corpus.
tkr = BasicTokenizer()
tkr.train(txt, vocab_size=300, verbose=False)

In [None]:
# Compression achieved by the trained tokenizer on the whole corpus:
# raw UTF-8 byte count vs. number of BPE token ids.
utf_tokens = [*txt.encode("utf-8")]
bpe_tokens = tkr.encode(txt)

print("before:", len(utf_tokens))
print("after:", len(bpe_tokens))
print("comprssion", len(utf_tokens) / len(bpe_tokens))

before: 185768
after: 128451
comprssion 1.4462168453340185


# RegexTokenizer

In [None]:
import regex as re

# Regex pre-tokenization patterns. The names indicate the GPT-2 and GPT-4 split
# rules (presumably taken from OpenAI's tokenizers via minbpe; verify).
# RegexTokenizer below cuts training text into chunks with GPT4_SPLIT_PATTERN
# by default. Both patterns use the \p{L} (letter) / \p{N} (number) Unicode
# property classes, which the stdlib `re` module does not support; hence
# `import regex as re`.
#
# GPT-2 alternatives, in order:
#   - contractions ('s 'd 'm 't 'll 've 're)
#   - runs of letters, of digits, or of other non-space symbols, each with at
#     most one leading space
#   - whitespace; the last space before a word is left for that word
GPT2_SPLIT_PATTERN = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
# GPT-4 differs from GPT-2 as follows:
#   - contractions are case-insensitive
#   - a letter run may take one leading non-letter/non-digit character other
#     than \r or \n
#   - digits are grouped at most three at a time
#   - symbol runs may absorb trailing line breaks
#   - whitespace ending in a line break forms its own chunk
GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""

# Example: leading spaces stay attached to the following word.
re.findall(GPT4_SPLIT_PATTERN, " hello world!")


[' hello', ' world', '!']

In [None]:
def insert_between(lst, element):
  """Return a new list with ``element`` placed between consecutive items of ``lst``.

  Example:
    insert_between([1, 2, 3], ",") -> [1, ",", 2, ",", 3]
  """
  out = []
  for position, item in enumerate(lst):
    if position:  # every item after the first is preceded by the separator
      out.append(element)
    out.append(item)
  return out


def segment(s, spliter):
  """Split ``s`` on ``spliter``, keeping each separator as its own list item.

  Example:
    segment("a,cd", ",") -> ["a", ",", "cd"]

  Empty strings are kept where ``spliter`` touches either end of ``s`` or
  repeats, e.g. segment(",a", ",") -> ["", ",", "a"]. Splitting uses
  ``str.rsplit``, so overlapping separator matches are resolved from the right.
  """
  pieces = s.rsplit(spliter)
  result = pieces[:1]
  for piece in pieces[1:]:
    result += [spliter, piece]
  return result


def enc(chunk, special_tokens):
  """Map a single chunk to a list of token ids.

  ``chunk`` is either exactly a special-token string (a key of
  ``special_tokens``), which becomes its single reserved id, or text containing
  no special token, which becomes its raw UTF-8 byte values.
  """
  if chunk not in special_tokens:
    return [*chunk.encode("utf-8")]
  return [special_tokens[chunk]]


# Demo of special-token handling: split the text around each special token,
# then turn every chunk into ids.
special_tokens = {
  "<|endoftext|>": 55555,
  "<|startoftext|>": 66666,
}

text = "<|startoftext|>abc<|endoftext|>"

# Re-segment every chunk around one special token at a time.
chunks = [text]
for token in special_tokens:
  chunks = [piece for chunk in chunks for piece in segment(chunk, token)]

# Special-token chunks map to their id; everything else maps to UTF-8 bytes.
ids = [tok_id for chunk in chunks for tok_id in enc(chunk, special_tokens)]
ids

[66666, 97, 98, 99, 55555]

In [None]:
class RegexTokenizer(BasicTokenizer):
  """BPE tokenizer that pre-splits text with a regex and supports special tokens.

  Text is cut into chunks with ``split_pattern``. Merges are learned and applied
  within each chunk only. Special-token strings bypass BPE and map directly to
  their reserved ids.
  """

  def __init__(self,
               split_pattern=None,
               special_tokens=None):
    """
    split_pattern: regex used to pre-split text into chunks; None selects
      GPT4_SPLIT_PATTERN.
    special_tokens: a dict that maps special token string to a id
    ex:
    { "<|endoftext|>": 55555 }
    The ids must lie outside the BPE id range [0, vocab_size); train() checks.
    """
    self.split_pattern = (GPT4_SPLIT_PATTERN if split_pattern is None
                          else split_pattern)
    self.special_tokens = {} if special_tokens is None else special_tokens
    # Untrained state (256 byte tokens, no merges), so encode()/decode() work
    # before train() instead of raising AttributeError.
    self.merges = {}
    self.vocab = {i: bytes([i]) for i in range(256)}

  def train(self, text, vocab_size, verbose=False):
    """Learn up to ``vocab_size - 256`` merges from ``text``.

    Args:
      text: training corpus.
      vocab_size: target vocabulary size; must exceed 256.
      verbose: print each merge as it is learned.

    Raises:
      ValueError: if a special-token id falls inside [0, vocab_size), where it
        would collide with a byte or merge id.
    """
    assert vocab_size > 256
    clashing = sorted(tok for tok, tok_id in self.special_tokens.items()
                      if 0 <= tok_id < vocab_size)
    if clashing:
      raise ValueError(
          f"special token ids collide with BPE ids [0, {vocab_size}): {clashing}")

    self.merges = {}
    self.vocab = {i: bytes([i]) for i in range(256)}

    # Split first so that no merge is ever learned across a chunk boundary.
    text_list = re.findall(self.split_pattern, text)
    ids_list = [list(t.encode("utf-8")) for t in text_list]
    num_merge = vocab_size - 256

    for i in range(num_merge):
      # Pair counts summed over all chunks.
      stat = {}
      for ids in ids_list:
        stat = get_stat(ids, stat)

      if not stat:  # every chunk is down to a single token
        break
      idx = 256 + i
      top_pair = max(stat, key=stat.get)  # ties -> first pair seen
      ids_list = [merge(ids, top_pair, idx) for ids in ids_list]

      p0, p1 = top_pair
      self.merges[top_pair] = idx
      self.vocab[idx] = self.vocab[p0] + self.vocab[p1]

      if verbose:
        print(f"merge ({self.vocab[p0]}, {self.vocab[p1]}) -> {self.vocab[idx]} | {top_pair} -> {idx}")

  def encode(self, text):
    """
    text -> list of int

    Special-token strings become their reserved ids. Every other span is split
    with ``split_pattern``, and each chunk is BPE-encoded on its own, matching
    how train() learned the merges. Characters that ``split_pattern`` does not
    match are dropped, as in training.
    """
    ids = []
    for pos, part in enumerate(self._split_special(text)):
      if pos % 2:
        # split() with one capture group alternates text, separator, text, ...
        # so odd positions are the special-token strings themselves.
        ids.append(self.special_tokens[part])
      else:
        for chunk in re.findall(self.split_pattern, part):
          ids.extend(self._encode_chunk(chunk))
    return ids

  def _split_special(self, text):
    """Split text around special tokens, keeping the tokens as list items."""
    names = sorted((tok for tok in self.special_tokens if tok),
                   key=len, reverse=True)
    if not names:
      return [text]
    # Longest first, so a token that is a prefix of another cannot shadow it.
    pattern = "(" + "|".join(re.escape(tok) for tok in names) + ")"
    return re.split(pattern, text)

  def _encode_chunk(self, chunk):
    """Byte-encode one regex chunk and replay the learned merges on it."""
    tokens = list(chunk.encode("utf-8"))
    for pair, idx in self.merges.items():  # learned (creation) order
      if len(tokens) < 2:
        break
      tokens = merge(tokens, pair, idx)
    return tokens

  def decode(self, ids):
    """Decode token ids (including special-token ids) back to text.

    Invalid UTF-8 byte sequences become U+FFFD. An id that is neither a
    special-token id nor in ``vocab`` raises KeyError.
    """
    special_bytes = {tok_id: tok.encode("utf-8")
                     for tok, tok_id in self.special_tokens.items()}
    parts = []
    for token_id in ids:
      if token_id in special_bytes:
        parts.append(special_bytes[token_id])
      else:
        parts.append(self.vocab[token_id])
    return b"".join(parts).decode("utf-8", errors="replace")


special_tokens = {
  "<|endoftext|>": 55555,
  "<|startoftext|>": 66666,
}

# Smoke test: learn a single merge (vocab 256 -> 257), then encode a string
# wrapped in both special tokens.
rex_tkr = RegexTokenizer(split_pattern=GPT4_SPLIT_PATTERN,
                         special_tokens=special_tokens)
rex_tkr.train("abc! abc! abc! hello hello world steven", vocab_size=257, verbose=True)
rex_tkr.encode("<|startoftext|>abc!<|endoftext|>")

merge (b'a', b'b') -> b'ab' | (97, 98) -> 256


[66666, 256, 99, 33, 55555]