## Test Run

In [None]:
# Short ASCII sample used to exercise the tokenizer on a tiny input.
text_to_tokenize = """Here is some text to tokenize. It is long and not very usefule but does work as a test"""
# Encode to UTF-8: every byte becomes one base token.
tokens = text_to_tokenize.encode("utf-8")
# Materialise the bytes as a plain list of integers, the starting token ids.
ids = [int(byte) for byte in tokens]

In [None]:
def get_stats(ids):
    """Count how often each pair of adjacent elements occurs in `ids`.

    Args:
        ids: a sliceable sequence of token ids (e.g. a list of ints or bytes).

    Returns:
        A dict mapping (left, right) tuples to their number of occurrences,
        with keys in order of first appearance. Empty when `ids` has fewer
        than two elements.
    """
    pair_counts = {}
    for left, right in zip(ids, ids[1:]):
        key = (left, right)
        if key in pair_counts:
            pair_counts[key] += 1
        else:
            pair_counts[key] = 1
    return pair_counts

def merge(ids, pair, idx):
  """Return a new list in which each occurrence of `pair` is replaced by `idx`.

  The scan runs left to right and matches do not overlap: once two elements
  have been merged, scanning resumes after them.

  Args:
    ids: indexable sequence of token ids; it is not modified.
    pair: two-element sequence (left id, right id) to look for.
    idx: the token id that replaces each matched pair.

  Returns:
    A new list of token ids.
  """
  merged = []
  total = len(ids)
  skip_next = False
  for pos in range(total):
    if skip_next:
      # this element was the right half of the pair merged on the previous step
      skip_next = False
      continue
    if pos + 1 < total and ids[pos] == pair[0] and ids[pos + 1] == pair[1]:
      merged.append(idx)
      skip_next = True
    else:
      merged.append(ids[pos])
  return merged

In [None]:
# --- learn the BPE merges on the sample text ---
vocab_size = 276 # the desired final vocabulary size
num_merges = vocab_size - 256 # ids 0..255 are the raw bytes; only the remainder is learned
ids = list(tokens) # fresh list of ints; each merge rebinds `ids` to a new, shorter list

merges = {} # (int, int) -> int
for i in range(num_merges):
  stats = get_stats(ids)
  if not stats:
    continue # no adjacent pairs left, nothing to merge on this step
  # most frequent adjacent pair; on ties the pair seen first wins
  pair = max(stats, key=lambda candidate: stats[candidate])
  idx = 256 + i
  print(f"merging {pair} into a new token {idx}")
  ids = merge(ids, pair, idx)
  merges[pair] = idx

merging (115, 32) into a new token 256
merging (32, 116) into a new token 257
merging (116, 32) into a new token 258
merging (101, 114) into a new token 259
merging (101, 32) into a new token 260
merging (105, 256) into a new token 261
merging (257, 101) into a new token 262
merging (257, 111) into a new token 263
merging (32, 97) into a new token 264
merging (72, 259) into a new token 265
merging (265, 260) into a new token 266
merging (266, 261) into a new token 267
merging (267, 115) into a new token 268
merging (268, 111) into a new token 269
merging (269, 109) into a new token 270
merging (270, 101) into a new token 271
merging (271, 262) into a new token 272
merging (272, 120) into a new token 273
merging (273, 116) into a new token 274
merging (274, 263) into a new token 275


In [None]:
# Inspect the learned merge table: (left id, right id) -> new token id.
merges

{(115, 32): 256,
 (32, 116): 257,
 (116, 32): 258,
 (101, 114): 259,
 (101, 32): 260,
 (105, 256): 261,
 (257, 101): 262,
 (257, 111): 263,
 (32, 97): 264,
 (72, 259): 265,
 (265, 260): 266,
 (266, 261): 267,
 (267, 115): 268,
 (268, 111): 269,
 (269, 109): 270,
 (270, 101): 271,
 (271, 262): 272,
 (272, 120): 273,
 (273, 116): 274,
 (274, 263): 275}

In [None]:
# Base vocabulary: each of the 256 byte values maps to its own single-byte string.
vocab = dict((byte_value, bytes([byte_value])) for byte_value in range(256))
# A learned token expands to the concatenation of its two parts. This relies on
# `merges` iterating in insertion (training) order, so both parts already exist.
for (p0, p1), idx in merges.items():
    vocab[idx] = b"".join((vocab[p0], vocab[p1]))

def decode(ids):
  """Turn a sequence of token ids back into a Python string.

  Each id is looked up in the module-level `vocab` (id -> bytes), the byte
  strings are concatenated, and the result is decoded as UTF-8. Byte
  sequences that are not valid UTF-8 become U+FFFD instead of raising.

  Args:
    ids: iterable of integer token ids.

  Returns:
    The decoded text (str).

  Raises:
    KeyError: if an id is not present in `vocab`.
  """
  raw = bytearray()
  for idx in ids:
    raw += vocab[idx]
  return bytes(raw).decode("utf-8", errors="replace")

# Byte 128 (0x80) on its own is not valid UTF-8, so errors="replace" yields U+FFFD.
print(decode([128]))

�


In [None]:
def encode(text):
  """Tokenize a string into a list of token ids using the learned `merges`.

  The text is UTF-8 encoded into byte ids, then merges are applied one at a
  time: among the adjacent pairs currently present, the one with the smallest
  merged id in the module-level `merges` table is replaced everywhere, until
  no present pair is in the table or fewer than two tokens remain.

  Args:
    text: the string to tokenize.

  Returns:
    A list of integer token ids (empty for an empty string).
  """
  ids = list(text.encode("utf-8"))

  def rank(candidate):
    # pairs without a learned merge rank last
    return merges.get(candidate, float("inf"))

  while len(ids) >= 2:
    best = min(get_stats(ids), key=rank)
    if best not in merges:
      break # nothing else can be merged
    ids = merge(ids, best, merges[best])
  return ids

# Edge case: an empty string has no bytes, so the merge loop never runs.
print(encode(""))

[]


In [None]:
# Round trip: decoding the encoded text should give back the original string.
print(decode(encode("hello world")))

hello world


In [None]:
# Decode two raw byte ids: 32 is a space and 116 is 't'.
decode([32, 116])

' t'

In [None]:
# Encode a phrase that does not appear in the sample text and show its token ids;
# ids >= 256 are learned merge tokens, ids below 256 are raw bytes.
t = encode("at the water park")
t

[97, 116, 257, 104, 260, 119, 97, 116, 259, 32, 112, 97, 114, 107]

## Full Run - Tokenize Shakespeare

In [1]:
# download the TinyShakespeare dataset
# (shell commands: fetch the file into the working directory, create ./input_folder
# if it is missing, then move the file into it)
!wget -O input.txt https://raw.githubusercontent.com/vvr-rao/my-mini-LLama/main/input_text/input.txt
!mkdir -p input_folder
!mv input.txt input_folder/

# load the dataset
# (read the whole file into a single string, decoding it as UTF-8)
with open('./input_folder/input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

--2024-05-25 10:38:46--  https://raw.githubusercontent.com/vvr-rao/my-mini-LLama/main/input_text/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.110.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2024-05-25 10:38:47 (5.53 MB/s) - ‘input.txt’ saved [1115394/1115394]



In [2]:
# Sanity check: number of characters loaded, then the first 100 characters.
print(len(text))
print(text[:100])

1115394
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You


In [3]:
def get_stats(ids):
    """Count how often each pair of adjacent elements occurs in `ids`.

    Args:
        ids: a sliceable sequence of token ids (e.g. a list of ints or bytes).

    Returns:
        A dict mapping (left, right) tuples to their number of occurrences,
        with keys in order of first appearance. Empty when `ids` has fewer
        than two elements.
    """
    pair_counts = {}
    for left, right in zip(ids, ids[1:]):
        key = (left, right)
        if key in pair_counts:
            pair_counts[key] += 1
        else:
            pair_counts[key] = 1
    return pair_counts

def merge(ids, pair, idx):
  """Return a new list in which each occurrence of `pair` is replaced by `idx`.

  The scan runs left to right and matches do not overlap: once two elements
  have been merged, scanning resumes after them.

  Args:
    ids: indexable sequence of token ids; it is not modified.
    pair: two-element sequence (left id, right id) to look for.
    idx: the token id that replaces each matched pair.

  Returns:
    A new list of token ids.
  """
  merged = []
  total = len(ids)
  skip_next = False
  for pos in range(total):
    if skip_next:
      # this element was the right half of the pair merged on the previous step
      skip_next = False
      continue
    if pos + 1 < total and ids[pos] == pair[0] and ids[pos + 1] == pair[1]:
      merged.append(idx)
      skip_next = True
    else:
      merged.append(ids[pos])
  return merged

In [4]:
# UTF-8 encode the full corpus; each byte is one base token.
tokens = text.encode("utf-8")
# Plain list of ints, one per byte.
ids = [int(byte) for byte in tokens]

In [5]:
from tqdm import tqdm

# --- learn the BPE merges on the full corpus ---
vocab_size = 512 # the desired final vocabulary size
num_merges = vocab_size - 256 # ids 0..255 are the raw bytes; only the remainder is learned
ids = list(tokens) # fresh list of ints; each merge rebinds `ids` to a new, shorter list

merges = {} # (int, int) -> int
for i in tqdm(range(num_merges)):
  stats = get_stats(ids)
  if not stats:
    continue # no adjacent pairs left, nothing to merge on this step
  # most frequent adjacent pair; on ties the pair seen first wins
  pair = max(stats, key=lambda candidate: stats[candidate])
  idx = 256 + i
  #print(f"merging {pair} into a new token {idx}")
  ids = merge(ids, pair, idx)
  merges[pair] = idx

100%|██████████| 256/256 [02:00<00:00,  2.12it/s]


In [6]:
#merges

In [7]:
import pickle
!mkdir -p vocab

file_name_merges = f'./vocab/merges.pkl'

with open(file_name_merges, 'wb') as f:
    pickle.dump(merges, f)

In [None]:
# NOTE(review): `vocab` is only rebuilt from the new `merges` in the next cell, so on a
# top-to-bottom run this reports whatever `vocab` an earlier cell left behind (or raises
# NameError if none exists). The saved output below does not match a 256-merge run and
# looks stale — TODO re-run this cell after the vocabulary has been rebuilt.
len(merges), len(vocab), type(merges), type(vocab)

(40, 296, dict, dict)

In [8]:
#merge the vocabulary and save it
import pickle

# Base vocabulary: each of the 256 byte values maps to its own single-byte string.
vocab = dict((byte_value, bytes([byte_value])) for byte_value in range(256))
# A learned token expands to the concatenation of its two parts. This relies on
# `merges` iterating in insertion (training) order, so both parts already exist.
for (p0, p1), idx in merges.items():
    vocab[idx] = b"".join((vocab[p0], vocab[p1]))

# Destination of the pickled vocabulary (token id -> bytes).
file_name = './vocab/vocab.pkl'

with open(file_name, 'wb') as f:
    pickle.dump(vocab, f)



In [None]:
# Read the vocabulary back from the pickle at `file_name` into `vocab2`.
# NOTE: pickle.load can execute arbitrary code; only load files you created yourself.
with open(file_name, 'rb') as f:
    vocab2 = pickle.load(f)

In [None]:
# Display the reloaded vocabulary (token id -> bytes).
vocab2

{0: b'\x00',
 1: b'\x01',
 2: b'\x02',
 3: b'\x03',
 4: b'\x04',
 5: b'\x05',
 6: b'\x06',
 7: b'\x07',
 8: b'\x08',
 9: b'\t',
 10: b'\n',
 11: b'\x0b',
 12: b'\x0c',
 13: b'\r',
 14: b'\x0e',
 15: b'\x0f',
 16: b'\x10',
 17: b'\x11',
 18: b'\x12',
 19: b'\x13',
 20: b'\x14',
 21: b'\x15',
 22: b'\x16',
 23: b'\x17',
 24: b'\x18',
 25: b'\x19',
 26: b'\x1a',
 27: b'\x1b',
 28: b'\x1c',
 29: b'\x1d',
 30: b'\x1e',
 31: b'\x1f',
 32: b' ',
 33: b'!',
 34: b'"',
 35: b'#',
 36: b'$',
 37: b'%',
 38: b'&',
 39: b"'",
 40: b'(',
 41: b')',
 42: b'*',
 43: b'+',
 44: b',',
 45: b'-',
 46: b'.',
 47: b'/',
 48: b'0',
 49: b'1',
 50: b'2',
 51: b'3',
 52: b'4',
 53: b'5',
 54: b'6',
 55: b'7',
 56: b'8',
 57: b'9',
 58: b':',
 59: b';',
 60: b'<',
 61: b'=',
 62: b'>',
 63: b'?',
 64: b'@',
 65: b'A',
 66: b'B',
 67: b'C',
 68: b'D',
 69: b'E',
 70: b'F',
 71: b'G',
 72: b'H',
 73: b'I',
 74: b'J',
 75: b'K',
 76: b'L',
 77: b'M',
 78: b'N',
 79: b'O',
 80: b'P',
 81: b'Q',
 82: b'R',
 83: b'

In [None]:

def decode(ids):
  """Turn a sequence of token ids back into a Python string.

  Each id is looked up in the module-level `vocab2` (id -> bytes), the byte
  strings are concatenated, and the result is decoded as UTF-8. Byte
  sequences that are not valid UTF-8 become U+FFFD instead of raising.

  Args:
    ids: iterable of integer token ids.

  Returns:
    The decoded text (str).

  Raises:
    KeyError: if an id is not present in `vocab2`.
  """
  raw = bytearray()
  for idx in ids:
    raw += vocab2[idx]
  return bytes(raw).decode("utf-8", errors="replace")

In [None]:
def encode(text):
  """Tokenize a string into a list of token ids using the learned `merges`.

  The text is UTF-8 encoded into byte ids, then merges are applied one at a
  time: among the adjacent pairs currently present, the one with the smallest
  merged id in the module-level `merges` table is replaced everywhere, until
  no present pair is in the table or fewer than two tokens remain.

  Args:
    text: the string to tokenize.

  Returns:
    A list of integer token ids (empty for an empty string).
  """
  ids = list(text.encode("utf-8"))

  def rank(candidate):
    # pairs without a learned merge rank last
    return merges.get(candidate, float("inf"))

  while len(ids) >= 2:
    best = min(get_stats(ids), key=rank)
    if best not in merges:
      break # nothing else can be merged
    ids = merge(ids, best, merges[best])
  return ids

In [None]:
# Decode a raw byte id (105, 'i') followed by a learned merge token id (259).
decode([105, 259])

'is '

In [None]:
# Round trip with the full-corpus tokenizer: the decoded text should equal the input.
decode(encode("Wherefore art thou Romeo!! and wherefore are the tater tots?"))

'Wherefore art thou Romeo!! and wherefore are the tater tots?'